From 85de0934727dc2c7b740b1d4a90d1a2e3c2d0585 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Thu, 29 Jun 2023 15:00:21 -0700
Subject: [PATCH] [Fix] Do not pin memory when in WSL (#312)

---
 vllm/utils.py               |  5 +++++
 vllm/worker/cache_engine.py | 14 ++++++++++++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index 85fe1877..eb686b64 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,4 +1,5 @@
 import enum
+from platform import uname
 import uuid
 
 import psutil
@@ -36,3 +37,7 @@ def get_cpu_memory() -> int:
 
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
+
+def in_wsl() -> bool:
+    # Reference: https://github.com/microsoft/WSL/issues/4071
+    return "microsoft" in " ".join(uname()).lower()
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 30b4ec7d..eb99bf75 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -5,6 +5,10 @@ import torch
 
 from vllm import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
+from vllm.logger import init_logger
+from vllm.utils import in_wsl
+
+logger = init_logger(__name__)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -85,16 +89,22 @@ class CacheEngine:
         cpu_cache: List[KVCache] = []
         key_block_shape = self.get_key_block_shape()
         value_block_shape = self.get_value_block_shape()
+        pin_memory = not in_wsl()
+        if not pin_memory:
+            # Pinning memory in WSL is not supported.
+            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
+            logger.warning("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down performance.")
         for _ in range(self.num_layers):
             key_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *key_block_shape),
                 dtype=self.dtype,
-                pin_memory=True,
+                pin_memory=pin_memory,
             )
             value_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *value_block_shape),
                 dtype=self.dtype,
-                pin_memory=True,
+                pin_memory=pin_memory,
            )
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
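
Below is a minimal standalone sketch of the pattern this patch applies: detect
WSL from the kernel's uname fields, then request pinned host memory only when
it is supported. It is an illustration, not part of the commit; it assumes
PyTorch is installed, and the names detect_wsl and allocate_host_buffer are
hypothetical stand-ins for the in_wsl() helper and the cache-allocation code
above.

    # Sketch: conditionally pin host memory, mirroring the logic in this patch.
    import platform

    import torch


    def detect_wsl() -> bool:
        # Illustrative stand-in for vllm.utils.in_wsl(): WSL kernels report
        # "microsoft" in their uname fields.
        # Reference: https://github.com/microsoft/WSL/issues/4071
        return "microsoft" in " ".join(platform.uname()).lower()


    def allocate_host_buffer(num_blocks: int, block_shape: tuple,
                             dtype: torch.dtype) -> torch.Tensor:
        # Pinned (page-locked) memory enables faster, asynchronous
        # host-to-device copies, but CUDA on WSL does not support pinned
        # allocations, so fall back to ordinary pageable memory there.
        pin_memory = not detect_wsl()
        return torch.empty((num_blocks, *block_shape),
                           dtype=dtype,
                           pin_memory=pin_memory)


    buffer = allocate_host_buffer(num_blocks=4, block_shape=(16, 128),
                                  dtype=torch.float16)
    print(buffer.is_pinned())

The fallback is deliberately soft: losing pinned memory only slows
host-to-device transfers, so the patch logs a warning rather than failing
outright on WSL.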