[Core] Faster startup for LoRA enabled models (#4634)

commit ad932a221d
parent 5510cf0e8a
Author: Antoni Baum
Date:   2024-05-08 10:33:18 -07:00 (committed by GitHub)
3 changed files with 47 additions and 18 deletions


@@ -119,6 +119,16 @@ class LoRAModel:
         self.rank = rank
         self.loras: Dict[str, LoRALayerWeights] = loras
 
+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
     @property
     def extra_vocab_size(self) -> int:
         return max(lora.extra_vocab_size
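
The key property of clone() is that it shallow-copies the adapter dict: the new model gets its own id and its own dict, but every weight tensor is shared with the original, so cloning costs no new allocations. A minimal runnable sketch of that semantics (TinyLoRAModel is a stand-in for illustration, not the real class):

    import torch

    class TinyLoRAModel:
        """Stand-in for LoRAModel, just to illustrate clone() semantics."""

        def __init__(self, lora_model_id, rank, loras):
            self.id = lora_model_id
            self.rank = rank
            self.loras = loras

        def clone(self, lora_model_id):
            # dict.copy() is shallow: a new dict, but the same tensor values.
            return self.__class__(lora_model_id, rank=self.rank,
                                  loras=self.loras.copy())

    base = TinyLoRAModel(1, rank=8, loras={"layer0": torch.zeros(8, 16)})
    copy = base.clone(2)
    assert copy.loras is not base.loras                  # independent dicts
    assert copy.loras["layer0"] is base.loras["layer0"]  # shared tensor storage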


@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union
 
 import torch
 
@@ -25,6 +26,17 @@ class AbstractWorkerLoRAManager(ABC):
         self.device = device
         self.lora_config = lora_config
 
+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
     @abstractproperty
     def is_enabled(self) -> bool:
         ...
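
The cache is deliberately tri-state: False means caching is disabled (the steady state outside the context manager), None means caching is enabled but nothing has been stored yet, and a LoRAModel value is the cached model itself. A runnable sketch of just that lifecycle, with illustrative names:

    from contextlib import contextmanager
    from typing import Literal, Union

    class CacheOwner:
        def __init__(self) -> None:
            # False: do not cache; None: cache enabled but empty; else: cached.
            self._cached: Union[None, Literal[False], str] = False

        @contextmanager
        def cache_enabled(self):
            self._cached = None
            yield
            self._cached = False  # reset on exit, dropping the reference

    owner = CacheOwner()
    assert owner._cached is False      # caching off by default
    with owner.cache_enabled():
        assert owner._cached is None   # enabled, empty
        owner._cached = "dummy model"  # the first build would populate it
    assert owner._cached is False      # reference released after the block

Resetting to False on exit also releases the cached dummy model (and its tensors) once warmup finishes; note the version in the diff does not wrap the yield in try/finally, so an exception inside the block would leave the flag as None.
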
@@ -174,9 +186,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
         if lora_request.lora_int_id in self.list_loras():
             return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+            if self._cached_dummy_lora is None:
+                self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if lora_request.lora_int_id in self.list_loras():
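
The branch structure matters because every active adapter needs a distinct lora_int_id: the cached model cannot be re-added as-is, so every call after the first gets a clone() with a fresh id but shared weights, and the isinstance check cleanly separates a cached model from both falsy sentinels (False and None). A self-contained toy tracing both branches (all names here are stand-ins, not vLLM API):

    from typing import Literal, Union

    class FakeLoRAModel:
        """Toy stand-in: clone() returns a new object sharing .loras."""

        def __init__(self, lora_id, loras):
            self.id, self.loras = lora_id, loras

        def clone(self, lora_id):
            return FakeLoRAModel(lora_id, self.loras)

    builds = 0

    def create_dummy(lora_id):
        # Stand-in for create_dummy_lora(), which allocates real tensors.
        global builds
        builds += 1
        return FakeLoRAModel(lora_id, loras={"w": object()})

    cached: Union[None, Literal[False], FakeLoRAModel] = None  # cache enabled
    added = []
    for lora_id in range(1, 9):  # e.g. max_loras == 8
        if isinstance(cached, FakeLoRAModel):
            dummy = cached.clone(lora_id)
        else:
            dummy = create_dummy(lora_id)
            if cached is None:
                cached = dummy
        added.append(dummy)

    assert builds == 1                      # only the first call builds tensors
    assert len({m.id for m in added}) == 8  # yet every adapter id is unique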


@@ -835,20 +835,21 @@ class ModelRunner:
         dummy_lora_requests = []
         dummy_lora_requests_per_seq = []
         if self.lora_config:
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_local_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
 
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
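
The list comprehension at the end round-robins the warmup adapters across the profiled sequences, so profiling exercises all max_num_seqs sequences even when max_loras is smaller. A quick illustration of the modulo assignment with toy values:

    dummy_lora_requests = ["lora_1", "lora_2", "lora_3"]  # max_loras == 3
    max_num_seqs = 8
    dummy_lora_requests_per_seq = [
        dummy_lora_requests[idx % len(dummy_lora_requests)]
        for idx in range(max_num_seqs)
    ]
    # -> ['lora_1', 'lora_2', 'lora_3', 'lora_1',
    #     'lora_2', 'lora_3', 'lora_1', 'lora_2']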