From ad932a221d2a4c1e6355021bb9e9c47f7a179e51 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 8 May 2024 10:33:18 -0700
Subject: [PATCH] [Core] Faster startup for LoRA enabled models (#4634)

---
 vllm/lora/models.py         | 10 ++++++++++
 vllm/lora/worker_manager.py | 26 ++++++++++++++++++++++----
 vllm/worker/model_runner.py | 29 +++++++++++++++--------------
 3 files changed, 47 insertions(+), 18 deletions(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 50d7e913..cd45040b 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -119,6 +119,16 @@ class LoRAModel:
         self.rank = rank
         self.loras: Dict[str, LoRALayerWeights] = loras
 
+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
     @property
     def extra_vocab_size(self) -> int:
         return max(lora.extra_vocab_size
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index ec3c10c5..377f561c 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union
 
 import torch
 
@@ -25,6 +26,17 @@ class AbstractWorkerLoRAManager(ABC):
         self.device = device
         self.lora_config = lora_config
 
+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
     @abstractproperty
     def is_enabled(self) -> bool:
         ...
@@ -174,9 +186,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
         if lora_request.lora_int_id in self.list_loras():
             return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+            if self._cached_dummy_lora is None:
+                self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if lora_request.lora_int_id in self.list_loras():
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index c96f13c5..46c67306 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -835,20 +835,21 @@ class ModelRunner:
         dummy_lora_requests = []
         dummy_lora_requests_per_seq = []
         if self.lora_config:
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_local_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
 
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
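
For illustration (not part of the diff above), the standalone sketch below mimics the caching pattern this patch introduces: `dummy_lora_cache()` enables caching for the duration of the warmup loop, the first dummy LoRA is built and cached, and every later request reuses it through `clone()`, which only assigns a new id while sharing the underlying tensors. `FakeLoRAModel`, `FakeWorkerLoRAManager`, and their methods are simplified stand-ins invented for this example, not the real vLLM classes or API.

```python
from contextlib import contextmanager
from typing import Dict, Literal, Union

import torch


class FakeLoRAModel:
    """Simplified stand-in for vllm.lora.models.LoRAModel."""

    def __init__(self, lora_id: int, rank: int,
                 loras: Dict[str, torch.Tensor]) -> None:
        self.id = lora_id
        self.rank = rank
        self.loras = loras

    def clone(self, lora_id: int) -> "FakeLoRAModel":
        # Shallow-copy the dict so both objects share the same tensors,
        # mirroring LoRAModel.clone() in the patch.
        return self.__class__(lora_id, self.rank, self.loras.copy())


class FakeWorkerLoRAManager:
    """Simplified stand-in for the worker LoRA manager in the patch."""

    def __init__(self) -> None:
        # False: caching disabled. None: caching enabled but empty.
        self._cached_dummy_lora: Union[None, Literal[False],
                                       FakeLoRAModel] = False
        self.creations = 0  # how many dummy LoRAs were actually built

    @contextmanager
    def dummy_lora_cache(self):
        # Enable caching for the duration of the block, then drop the cache.
        self._cached_dummy_lora = None
        yield
        self._cached_dummy_lora = False

    def _create_dummy_lora(self, lora_id: int, rank: int) -> FakeLoRAModel:
        # Stand-in for the (expensive) create_dummy_lora call.
        self.creations += 1
        return FakeLoRAModel(lora_id, rank, {"layer.0": torch.zeros(rank, 16)})

    def add_dummy_lora(self, lora_id: int, rank: int) -> FakeLoRAModel:
        if isinstance(self._cached_dummy_lora, FakeLoRAModel):
            # Reuse the cached dummy: new id, shared tensors.
            return self._cached_dummy_lora.clone(lora_id)
        dummy = self._create_dummy_lora(lora_id, rank)
        if self._cached_dummy_lora is None:
            self._cached_dummy_lora = dummy
        return dummy


if __name__ == "__main__":
    manager = FakeWorkerLoRAManager()
    with manager.dummy_lora_cache():
        loras = [
            manager.add_dummy_lora(lora_id=i + 1, rank=8) for i in range(8)
        ]
    # Only the first dummy LoRA was built; the other seven are clones that
    # share its tensors, which is what speeds up warmup with many max_loras.
    assert manager.creations == 1
    assert all(l.loras["layer.0"] is loras[0].loras["layer.0"] for l in loras)
```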