[Core] Faster startup for LoRA enabled models (#4634)
parent 5510cf0e8a
commit ad932a221d
@@ -119,6 +119,16 @@ class LoRAModel:
         self.rank = rank
         self.loras: Dict[str, LoRALayerWeights] = loras
 
+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
     @property
     def extra_vocab_size(self) -> int:
         return max(lora.extra_vocab_size
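The clone() added above is cheap because dict.copy() is shallow: the clone gets a new id but references the same underlying weight tensors. Below is a minimal sketch of that sharing behavior, assuming only torch is available; FakeLoRAModel is an illustrative stand-in, not vLLM's LoRAModel class.

import torch


class FakeLoRAModel:
    """Stand-in that mirrors the clone() pattern from the hunk above."""

    def __init__(self, lora_model_id: int, rank: int, loras: dict):
        self.id = lora_model_id
        self.rank = rank
        self.loras = loras

    def clone(self, lora_model_id: int) -> "FakeLoRAModel":
        # New id, shallow-copied dict: the tensors themselves are shared.
        return self.__class__(lora_model_id, rank=self.rank,
                              loras=self.loras.copy())


original = FakeLoRAModel(1, rank=8, loras={"layer0": torch.zeros(8, 16)})
copy_ = original.clone(2)
assert copy_.loras is not original.loras                   # separate dict
assert copy_.loras["layer0"] is original.loras["layer0"]   # same tensor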
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union
 
 import torch
 
@@ -25,6 +26,17 @@ class AbstractWorkerLoRAManager(ABC):
         self.device = device
         self.lora_config = lora_config
 
+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
     @abstractproperty
     def is_enabled(self) -> bool:
         ...
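The cache above is deliberately tri-state: False means caching is disabled, None means caching is enabled but nothing has been stored yet, and a LoRAModel instance means a dummy model is cached. A hedged sketch of the same pattern in isolation follows; the class and method names are illustrative, not vLLM APIs.

from contextlib import contextmanager
from typing import Literal, Union


class Expensive:
    """Placeholder for an object that is costly to build (e.g. a dummy LoRA)."""


class CachingManager:

    def __init__(self) -> None:
        # False = do not cache, None = cache enabled but empty, else = cached.
        self._cached: Union[None, Literal[False], Expensive] = False

    @contextmanager
    def cache(self):
        # Enable caching only for the duration of the block.
        self._cached = None
        yield
        self._cached = False

    def get(self) -> Expensive:
        if isinstance(self._cached, Expensive):
            return self._cached       # reuse instead of rebuilding
        obj = Expensive()             # the expensive construction
        if self._cached is None:
            self._cached = obj        # first build inside the block is kept
        return obj


manager = CachingManager()
with manager.cache():
    assert manager.get() is manager.get()   # built only once inside the block
assert manager.get() is not manager.get()   # no caching outside the block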
@@ -174,9 +186,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
         if lora_request.lora_int_id in self.list_loras():
             return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+            if self._cached_dummy_lora is None:
+                self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if lora_request.lora_int_id in self.list_loras():
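With the branch above, only the first add_dummy_lora call inside a dummy_lora_cache() block pays for create_dummy_lora; every later call clones the cached model and only rebinds the id. A rough sketch of the intended effect, with an explicit creation counter; all names here are illustrative stubs, not vLLM internals.

creations = 0


class StubLoRA:

    def __init__(self, lora_id: int):
        global creations
        creations += 1                  # count expensive constructions
        self.id = lora_id

    def clone(self, lora_id: int) -> "StubLoRA":
        new = object.__new__(StubLoRA)  # cheap: skip the expensive __init__
        new.id = lora_id
        return new


cached = None                           # plays the role of _cached_dummy_lora
adapters = []
for lora_id in range(1, 9):             # e.g. max_loras == 8 warmup slots
    if isinstance(cached, StubLoRA):
        adapters.append(cached.clone(lora_id))
    else:
        cached = StubLoRA(lora_id)
        adapters.append(cached)

assert creations == 1                   # built once, cloned seven times
assert [a.id for a in adapters] == list(range(1, 9))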
@@ -835,20 +835,21 @@ class ModelRunner:
         dummy_lora_requests = []
         dummy_lora_requests_per_seq = []
         if self.lora_config:
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_local_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
 
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.