[Core] Faster startup for LoRA enabled models (#4634)
This commit is contained in:
parent
5510cf0e8a
commit
ad932a221d
@ -119,6 +119,16 @@ class LoRAModel:
|
||||
self.rank = rank
|
||||
self.loras: Dict[str, LoRALayerWeights] = loras
|
||||
|
||||
def clone(self, lora_model_id: int) -> "LoRAModel":
|
||||
"""Return a copy of the object with different ids.
|
||||
|
||||
Will share the underlying tensors."""
|
||||
return self.__class__(
|
||||
lora_model_id,
|
||||
rank=self.rank,
|
||||
loras=self.loras.copy(),
|
||||
)
|
||||
|
||||
@property
|
||||
def extra_vocab_size(self) -> int:
|
||||
return max(lora.extra_vocab_size
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import Any, Dict, List, Set, Type
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Literal, Set, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -25,6 +26,17 @@ class AbstractWorkerLoRAManager(ABC):
|
||||
self.device = device
|
||||
self.lora_config = lora_config
|
||||
|
||||
# If False, do not cache. If None, cache is empty.
|
||||
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
|
||||
|
||||
@contextmanager
|
||||
def dummy_lora_cache(self):
|
||||
"""Use this context manager to reuse the dummy lora model
|
||||
to avoid creating it repeatedly."""
|
||||
self._cached_dummy_lora = None
|
||||
yield
|
||||
self._cached_dummy_lora = False
|
||||
|
||||
@abstractproperty
|
||||
def is_enabled(self) -> bool:
|
||||
...
|
||||
@ -174,9 +186,15 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
|
||||
if lora_request.lora_int_id in self.list_loras():
|
||||
return False
|
||||
return self._lora_manager.add_lora(
|
||||
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
|
||||
rank, self.embedding_modules))
|
||||
if isinstance(self._cached_dummy_lora, LoRAModel):
|
||||
dummy_lora = self._cached_dummy_lora.clone(
|
||||
lora_request.lora_int_id)
|
||||
else:
|
||||
dummy_lora = self._lora_manager.create_dummy_lora(
|
||||
lora_request.lora_int_id, rank, self.embedding_modules)
|
||||
if self._cached_dummy_lora is None:
|
||||
self._cached_dummy_lora = dummy_lora
|
||||
return self._lora_manager.add_lora(dummy_lora)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if lora_request.lora_int_id in self.list_loras():
|
||||
|
||||
@ -835,20 +835,21 @@ class ModelRunner:
|
||||
dummy_lora_requests = []
|
||||
dummy_lora_requests_per_seq = []
|
||||
if self.lora_config:
|
||||
for idx in range(self.lora_config.max_loras):
|
||||
lora_id = idx + 1
|
||||
dummy_lora_request = LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_local_path="/not/a/real/path",
|
||||
)
|
||||
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
||||
rank=LORA_WARMUP_RANK)
|
||||
dummy_lora_requests.append(dummy_lora_request)
|
||||
dummy_lora_requests_per_seq = [
|
||||
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
||||
for idx in range(max_num_seqs)
|
||||
]
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
for idx in range(self.lora_config.max_loras):
|
||||
lora_id = idx + 1
|
||||
dummy_lora_request = LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_local_path="/not/a/real/path",
|
||||
)
|
||||
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
||||
rank=LORA_WARMUP_RANK)
|
||||
dummy_lora_requests.append(dummy_lora_request)
|
||||
dummy_lora_requests_per_seq = [
|
||||
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
||||
for idx in range(max_num_seqs)
|
||||
]
|
||||
|
||||
# Profile memory usage with max_num_sequences sequences and the total
|
||||
# number of tokens equal to max_num_batched_tokens.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user