[LoRA] Add support for pinning lora adapters in the LRU cache (#5603)
commit f5dda63eb5
parent 7187507301
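In short: this change threads a new `pin_lora(lora_id)` call from `LLMEngine` through every executor and worker implementation down to the LoRA managers, and backs it with pinning support in the shared `LRUCache`. Pinned adapters are skipped by LRU eviction, `remove_oldest` raises `RuntimeError` once every remaining entry is pinned, and pinning an unregistered adapter raises `ValueError`.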
@@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     assert manager.activate_lora(3)
     assert manager.lora_index_to_id[0] == 2
     assert manager.lora_index_to_id[1] == 3
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 3
+    assert manager.activate_lora(1)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.deactivate_lora(2)
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.activate_lora(3)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.pin_lora(3)
+    assert manager.pin_lora(1)
+    with pytest.raises(RuntimeError):
+        assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    with pytest.raises(RuntimeError):
+        assert manager.activate_lora(2)
+
+    assert manager.deactivate_lora(3)
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.remove_lora(3)
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(3)
 
 
 def test_lru_lora_model_manager(dist_init, dummy_model):
@@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
     assert set(manager.list_loras()) == set()
     assert all(x is None for x in manager.lora_index_to_id)
 
+    # pinning
+    assert manager.add_lora(model_lora3)
+    assert manager.activate_lora(3)
+    assert manager.add_lora(model_lora4)
+    assert manager.activate_lora(4)
+    assert set(manager.list_loras()) == {3, 4}
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(1)
+    assert manager.pin_lora(3)
+    # Remove manually
+    assert manager.remove_lora(3)
+    assert not manager.remove_lora(3)
+
+    assert set(manager.list_loras()) == {4}
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 4
+
+    assert manager.add_lora(model_lora1)
+    assert manager.pin_lora(1)
+    assert manager.add_lora(model_lora2)
+    assert manager.activate_lora(2)
+
+    assert set(manager.list_loras()) == {1, 2}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] == 2
+
+    assert manager.remove_oldest_lora()
+    assert set(manager.list_loras()) == {1}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] is None
+
+    with pytest.raises(RuntimeError):
+        assert manager.remove_oldest_lora()
+
+    assert set(manager.list_loras()) == {1}
+
 
 def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                        sql_lora_files):
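Taken together, the new assertions pin down the contract: a pinned adapter keeps its slot while unpinned neighbors are activated, deactivated, and evicted around it; pinning an id that is not registered raises `ValueError`; and once every slot is pinned, `pin_lora` or `activate_lora` for a further adapter, as well as `remove_oldest_lora`, raise `RuntimeError`.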
@@ -1009,6 +1009,9 @@ class LLMEngine:
     def list_loras(self) -> Set[int]:
         return self.model_executor.list_loras()
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_executor.pin_lora(lora_id)
+
     def check_health(self) -> None:
         self.model_executor.check_health()
 
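For orientation, a caller-side sketch of how the new entry point composes with the existing `add_lora` API. This is not code from the commit: the model name and adapter path are placeholders, while `EngineArgs`, `LLMEngine`, and `LoRARequest` are existing vLLM APIs.

from vllm import EngineArgs, LLMEngine
from vllm.lora.request import LoRARequest

# Build an engine with LoRA enabled (model/path values are placeholders).
engine = LLMEngine.from_engine_args(
    EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))

# Register an adapter, then pin it so LRU eviction never reclaims it.
engine.add_lora(LoRARequest("sql-adapter", 1, "/path/to/adapter"))
assert engine.pin_lora(1)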
@@ -84,6 +84,9 @@ class CPUExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
@@ -100,6 +100,13 @@ class DistributedGPUExecutor(GPUExecutor):
             lora_id=lora_id,
         )
 
+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "pin_lora",
+            lora_id=lora_id,
+        )
+
     def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
 
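As with `remove_lora` above it, the distributed executor fans `pin_lora` out to every worker via `_run_workers`, keeping the pin state consistent across all ranks.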
@@ -86,6 +86,10 @@ class ExecutorBase(ABC):
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    @abstractmethod
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError  # type: ignore
+
     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
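`pin_lora` joins the executor ABC as an `@abstractmethod`, so every backend (CPU, GPU, distributed, Neuron) must provide an implementation; the stubs above and below are exactly that.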
@@ -99,6 +99,10 @@ class GPUExecutor(ExecutorBase):
         assert lora_id > 0, "lora_id must be greater than 0."
         return self.driver_worker.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
@@ -65,6 +65,9 @@ class NeuronExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
@@ -525,6 +525,12 @@ class LoRAModelManager:
             self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
         return bool(self._registered_loras.pop(lora_id, None))
 
+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        raise NotImplementedError(
+            "Pinning is not supported in LoRAModelManager."
+            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore
+
     # TODO see if this can be vectorized
     def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
         (base_indices, sampler_indices, sampler_indices_padded,
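The base `LoRAModelManager` deliberately rejects pinning: without an LRU cache there is no eviction to protect against, so the capability lives only in `LRUCacheLoRAModelManager` below.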
@@ -777,6 +783,26 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
             return True
         return False
 
+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        self._pin_lora_in_cpu_cache(lora_id)
+        self._pin_lora_in_gpu_cache(lora_id)
+        return True
+
+    def _pin_lora_in_cpu_cache(self, lora_id: int):
+        try:
+            self._registered_loras.pin(lora_id)
+        except ValueError as err:
+            raise ValueError("Pinning failed. "
+                             f"LoRA {lora_id} is not registered.") from err
+
+    def _pin_lora_in_gpu_cache(self, lora_id: int):
+        if lora_id not in self._active_loras:
+            # move lora to gpu if not already active
+            self.activate_lora(lora_id)
+
+        self._active_loras.pin(lora_id)
+
 
 def create_lora_manager(
     model: nn.Module,
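Pinning is two-level, mirroring the manager's two caches: the adapter is pinned in the CPU-side registered cache, then activated (loaded to GPU) if necessary and pinned in the GPU-side active cache. A `ValueError` from the underlying cache is re-raised with a message naming the unregistered `lora_id`.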
@@ -221,6 +221,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def remove_lora(self, lora_id: int) -> bool:
         return self._lora_manager.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self._lora_manager.pin_lora(lora_id)
+
     def remove_all_loras(self):
         self._lora_manager.remove_all_loras()
 
@@ -15,7 +15,7 @@ from collections import defaultdict
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
+                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                     Union)
 
 import numpy as np
@@ -44,6 +44,13 @@ K = TypeVar("K")
 T = TypeVar("T")
 
 
+class _Sentinel:
+    ...
+
+
+ALL_PINNED_SENTINEL = _Sentinel()
+
+
 class Device(enum.Enum):
     GPU = enum.auto()
     CPU = enum.auto()
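A brief aside on the sentinel idiom, as an illustrative sketch rather than code from this commit: `next(iterator, default)` needs a default that cannot collide with any real cache key, and a fresh private instance compared with `is` guarantees uniqueness.

_MISSING = object()  # stands in for ALL_PINNED_SENTINEL

# Find the first even number; the sentinel distinguishes "not found"
# from any legitimate value the generator could yield.
first_even = next((x for x in (1, 3, 5) if x % 2 == 0), _MISSING)
if first_even is _MISSING:
    print("every candidate was filtered out")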
@@ -67,6 +74,7 @@ class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
         self.cache: OrderedDict[Hashable, T] = OrderedDict()
+        self.pinned_items: Set[Hashable] = set()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
@@ -102,14 +110,36 @@ class LRUCache(Generic[T]):
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
+    def pin(self, key: Hashable) -> None:
+        """
+        Pins a key in the cache preventing it from being
+        evicted in the LRU order.
+        """
+        if key not in self.cache:
+            raise ValueError(f"Cannot pin key: {key} not in cache.")
+        self.pinned_items.add(key)
+
+    def _unpin(self, key: Hashable) -> None:
+        self.pinned_items.remove(key)
+
     def _on_remove(self, key: Hashable, value: Optional[T]):
         pass
 
-    def remove_oldest(self):
+    def remove_oldest(self, remove_pinned=False):
         if not self.cache:
             return
-        key, value = self.cache.popitem(last=False)
-        self._on_remove(key, value)
+
+        if not remove_pinned:
+            # pop the oldest item in the cache that is not pinned
+            lru_key = next(
+                (key for key in self.cache if key not in self.pinned_items),
+                ALL_PINNED_SENTINEL)
+            if lru_key is ALL_PINNED_SENTINEL:
+                raise RuntimeError("All items are pinned, "
+                                   "cannot remove oldest from the cache.")
+        else:
+            lru_key = next(iter(self.cache))
+        self.pop(lru_key)
 
     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
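A minimal sketch of the new cache semantics. The import assumes this `LRUCache` is the one in `vllm.utils` (file paths are not visible in this diff), and `put` is the cache's existing insertion method.

from vllm.utils import LRUCache

cache: LRUCache[int] = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)

cache.pin("a")         # "a" is now exempt from LRU eviction
cache.remove_oldest()  # evicts "b", the oldest *unpinned* entry

try:
    cache.remove_oldest()  # only the pinned "a" remains
except RuntimeError:
    print("all items are pinned")

try:
    cache.pin("c")  # never inserted
except ValueError:
    print("cannot pin a key that is not cached")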
@@ -120,13 +150,16 @@ class LRUCache(Generic[T]):
             default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
         value: Optional[T] = self.cache.pop(key, default_value)
+        # remove from pinned items
+        if key in self.pinned_items:
+            self._unpin(key)
         if run_on_remove:
             self._on_remove(key, value)
         return value
 
     def clear(self):
         while len(self.cache) > 0:
-            self.remove_oldest()
+            self.remove_oldest(remove_pinned=True)
         self.cache.clear()
 
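Two complementary details here: `pop` unpins a key on its way out so `pinned_items` cannot accumulate stale keys, and `clear` opts into `remove_pinned=True` so that even a fully pinned cache can be torn down without tripping the new `RuntimeError`.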
@@ -878,6 +878,11 @@ class ModelRunner:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
@@ -333,6 +333,9 @@ class Worker(WorkerBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.model_runner.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.model_runner.list_loras()
 
@@ -70,6 +70,10 @@ class WorkerBase(ABC):
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    @abstractmethod
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
@@ -86,6 +90,10 @@ class LoraNotSupportedWorkerBase(WorkerBase):
     def remove_lora(self, lora_id: int) -> bool:
         raise ValueError(f"{type(self)} does not support LoRA")
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return ValueError(
+            f"{type(self)} does not support LoRA")  # type: ignore
+
     def list_loras(self) -> Set[int]:
         raise ValueError(f"{type(self)} does not support LoRA")
 
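Note that, unlike the neighboring `remove_lora` and `list_loras` stubs, `pin_lora` here returns the `ValueError` rather than raising it; the `# type: ignore` masks the resulting type mismatch.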