From 5265631d15d59735152c8b72b38d960110987f10 Mon Sep 17 00:00:00 2001 From: Vladimir Date: Fri, 26 Jan 2024 08:48:17 +0100 Subject: [PATCH] use a correct device when creating OptionalCUDAGuard (#2583) --- csrc/cache_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 9f173534..b7523cb4 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -34,7 +34,7 @@ void swap_blocks( char *dst_ptr = static_cast<char*>(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device); + const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. for (const auto& pair : block_mapping) {