[LoRA] Add support for pinning lora adapters in the LRU cache (#5603)
This commit is contained in:
parent
7187507301
commit
f5dda63eb5
@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
|
|||||||
assert manager.activate_lora(3)
|
assert manager.activate_lora(3)
|
||||||
assert manager.lora_index_to_id[0] == 2
|
assert manager.lora_index_to_id[0] == 2
|
||||||
assert manager.lora_index_to_id[1] == 3
|
assert manager.lora_index_to_id[1] == 3
|
||||||
|
assert manager.pin_lora(2)
|
||||||
|
assert manager.lora_index_to_id[0] == 2
|
||||||
|
assert manager.lora_index_to_id[1] == 3
|
||||||
|
assert manager.activate_lora(1)
|
||||||
|
assert manager.lora_index_to_id[0] == 2
|
||||||
|
assert manager.lora_index_to_id[1] == 1
|
||||||
|
assert manager.deactivate_lora(2)
|
||||||
|
assert manager.lora_index_to_id[0] is None
|
||||||
|
assert manager.lora_index_to_id[1] == 1
|
||||||
|
assert manager.activate_lora(3)
|
||||||
|
assert manager.lora_index_to_id[0] == 3
|
||||||
|
assert manager.lora_index_to_id[1] == 1
|
||||||
|
assert manager.pin_lora(3)
|
||||||
|
assert manager.pin_lora(1)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
assert manager.pin_lora(2)
|
||||||
|
assert manager.lora_index_to_id[0] == 3
|
||||||
|
assert manager.lora_index_to_id[1] == 1
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
assert manager.activate_lora(2)
|
||||||
|
|
||||||
|
assert manager.deactivate_lora(3)
|
||||||
|
assert manager.pin_lora(2)
|
||||||
|
assert manager.lora_index_to_id[0] == 2
|
||||||
|
assert manager.lora_index_to_id[1] == 1
|
||||||
|
assert manager.remove_lora(3)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert manager.pin_lora(3)
|
||||||
|
|
||||||
|
|
||||||
def test_lru_lora_model_manager(dist_init, dummy_model):
|
def test_lru_lora_model_manager(dist_init, dummy_model):
|
||||||
@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
|
|||||||
assert set(manager.list_loras()) == set()
|
assert set(manager.list_loras()) == set()
|
||||||
assert all(x is None for x in manager.lora_index_to_id)
|
assert all(x is None for x in manager.lora_index_to_id)
|
||||||
|
|
||||||
|
# pinning
|
||||||
|
assert manager.add_lora(model_lora3)
|
||||||
|
assert manager.activate_lora(3)
|
||||||
|
assert manager.add_lora(model_lora4)
|
||||||
|
assert manager.activate_lora(4)
|
||||||
|
assert set(manager.list_loras()) == {3, 4}
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert manager.pin_lora(1)
|
||||||
|
assert manager.pin_lora(3)
|
||||||
|
# Remove manually
|
||||||
|
assert manager.remove_lora(3)
|
||||||
|
assert not manager.remove_lora(3)
|
||||||
|
|
||||||
|
assert set(manager.list_loras()) == {4}
|
||||||
|
assert manager.lora_index_to_id[0] is None
|
||||||
|
assert manager.lora_index_to_id[1] == 4
|
||||||
|
|
||||||
|
assert manager.add_lora(model_lora1)
|
||||||
|
assert manager.pin_lora(1)
|
||||||
|
assert manager.add_lora(model_lora2)
|
||||||
|
assert manager.activate_lora(2)
|
||||||
|
|
||||||
|
assert set(manager.list_loras()) == {1, 2}
|
||||||
|
assert manager.lora_index_to_id[0] == 1
|
||||||
|
assert manager.lora_index_to_id[1] == 2
|
||||||
|
|
||||||
|
assert manager.remove_oldest_lora()
|
||||||
|
assert set(manager.list_loras()) == {1}
|
||||||
|
assert manager.lora_index_to_id[0] == 1
|
||||||
|
assert manager.lora_index_to_id[1] is None
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
assert manager.remove_oldest_lora()
|
||||||
|
|
||||||
|
assert set(manager.list_loras()) == {1}
|
||||||
|
|
||||||
|
|
||||||
def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
||||||
sql_lora_files):
|
sql_lora_files):
|
||||||
|
|||||||
@ -1009,6 +1009,9 @@ class LLMEngine:
|
|||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.model_executor.list_loras()
|
return self.model_executor.list_loras()
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
return self.model_executor.pin_lora(lora_id)
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
self.model_executor.check_health()
|
self.model_executor.check_health()
|
||||||
|
|
||||||
|
|||||||
@ -84,6 +84,9 @@ class CPUExecutor(ExecutorBase):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
return self.driver_worker.remove_lora(lora_id)
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
return self.driver_worker.pin_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.driver_worker.list_loras()
|
return self.driver_worker.list_loras()
|
||||||
|
|
||||||
|
|||||||
@ -100,6 +100,13 @@ class DistributedGPUExecutor(GPUExecutor):
|
|||||||
lora_id=lora_id,
|
lora_id=lora_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
assert lora_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self._run_workers(
|
||||||
|
"pin_lora",
|
||||||
|
lora_id=lora_id,
|
||||||
|
)
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self._run_workers("list_loras")
|
return self._run_workers("list_loras")
|
||||||
|
|
||||||
|
|||||||
@ -86,6 +86,10 @@ class ExecutorBase(ABC):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
raise NotImplementedError # type: ignore
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -99,6 +99,10 @@ class GPUExecutor(ExecutorBase):
|
|||||||
assert lora_id > 0, "lora_id must be greater than 0."
|
assert lora_id > 0, "lora_id must be greater than 0."
|
||||||
return self.driver_worker.remove_lora(lora_id)
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
assert lora_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self.driver_worker.pin_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.driver_worker.list_loras()
|
return self.driver_worker.list_loras()
|
||||||
|
|
||||||
|
|||||||
@ -65,6 +65,9 @@ class NeuronExecutor(ExecutorBase):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
return self.driver_worker.remove_lora(lora_id)
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
return self.driver_worker.pin_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.driver_worker.list_loras()
|
return self.driver_worker.list_loras()
|
||||||
|
|
||||||
|
|||||||
@ -525,6 +525,12 @@ class LoRAModelManager:
|
|||||||
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
|
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
|
||||||
return bool(self._registered_loras.pop(lora_id, None))
|
return bool(self._registered_loras.pop(lora_id, None))
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
"""Pin a LoRAModel in the manager cache."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Pinning is not supported in LoRAModelManager."
|
||||||
|
"Use LRUCacheLoRAModelManager for pinning") # type: ignore
|
||||||
|
|
||||||
# TODO see if this can be vectorized
|
# TODO see if this can be vectorized
|
||||||
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
|
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
|
||||||
(base_indices, sampler_indices, sampler_indices_padded,
|
(base_indices, sampler_indices, sampler_indices_padded,
|
||||||
@ -777,6 +783,26 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
"""Pin a LoRAModel in the manager cache."""
|
||||||
|
self._pin_lora_in_cpu_cache(lora_id)
|
||||||
|
self._pin_lora_in_gpu_cache(lora_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _pin_lora_in_cpu_cache(self, lora_id: int):
|
||||||
|
try:
|
||||||
|
self._registered_loras.pin(lora_id)
|
||||||
|
except ValueError as err:
|
||||||
|
raise ValueError("Pinning failed. "
|
||||||
|
f"LoRA {lora_id} is not registered.") from err
|
||||||
|
|
||||||
|
def _pin_lora_in_gpu_cache(self, lora_id: int):
|
||||||
|
if lora_id not in self._active_loras:
|
||||||
|
# move lora to gpu if not already active
|
||||||
|
self.activate_lora(lora_id)
|
||||||
|
|
||||||
|
self._active_loras.pin(lora_id)
|
||||||
|
|
||||||
|
|
||||||
def create_lora_manager(
|
def create_lora_manager(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
|||||||
@ -221,6 +221,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
return self._lora_manager.remove_lora(lora_id)
|
return self._lora_manager.remove_lora(lora_id)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
return self._lora_manager.pin_lora(lora_id)
|
||||||
|
|
||||||
def remove_all_loras(self):
|
def remove_all_loras(self):
|
||||||
self._lora_manager.remove_all_loras()
|
self._lora_manager.remove_all_loras()
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from collections import defaultdict
|
|||||||
from functools import lru_cache, partial, wraps
|
from functools import lru_cache, partial, wraps
|
||||||
from platform import uname
|
from platform import uname
|
||||||
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
|
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
|
||||||
Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
|
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -44,6 +44,13 @@ K = TypeVar("K")
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class _Sentinel:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
ALL_PINNED_SENTINEL = _Sentinel()
|
||||||
|
|
||||||
|
|
||||||
class Device(enum.Enum):
|
class Device(enum.Enum):
|
||||||
GPU = enum.auto()
|
GPU = enum.auto()
|
||||||
CPU = enum.auto()
|
CPU = enum.auto()
|
||||||
@ -67,6 +74,7 @@ class LRUCache(Generic[T]):
|
|||||||
|
|
||||||
def __init__(self, capacity: int):
|
def __init__(self, capacity: int):
|
||||||
self.cache: OrderedDict[Hashable, T] = OrderedDict()
|
self.cache: OrderedDict[Hashable, T] = OrderedDict()
|
||||||
|
self.pinned_items: Set[Hashable] = set()
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
|
|
||||||
def __contains__(self, key: Hashable) -> bool:
|
def __contains__(self, key: Hashable) -> bool:
|
||||||
@ -102,14 +110,36 @@ class LRUCache(Generic[T]):
|
|||||||
self.cache.move_to_end(key)
|
self.cache.move_to_end(key)
|
||||||
self._remove_old_if_needed()
|
self._remove_old_if_needed()
|
||||||
|
|
||||||
|
def pin(self, key: Hashable) -> None:
|
||||||
|
"""
|
||||||
|
Pins a key in the cache preventing it from being
|
||||||
|
evicted in the LRU order.
|
||||||
|
"""
|
||||||
|
if key not in self.cache:
|
||||||
|
raise ValueError(f"Cannot pin key: {key} not in cache.")
|
||||||
|
self.pinned_items.add(key)
|
||||||
|
|
||||||
|
def _unpin(self, key: Hashable) -> None:
|
||||||
|
self.pinned_items.remove(key)
|
||||||
|
|
||||||
def _on_remove(self, key: Hashable, value: Optional[T]):
|
def _on_remove(self, key: Hashable, value: Optional[T]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def remove_oldest(self):
|
def remove_oldest(self, remove_pinned=False):
|
||||||
if not self.cache:
|
if not self.cache:
|
||||||
return
|
return
|
||||||
key, value = self.cache.popitem(last=False)
|
|
||||||
self._on_remove(key, value)
|
if not remove_pinned:
|
||||||
|
# pop the oldest item in the cache that is not pinned
|
||||||
|
lru_key = next(
|
||||||
|
(key for key in self.cache if key not in self.pinned_items),
|
||||||
|
ALL_PINNED_SENTINEL)
|
||||||
|
if lru_key is ALL_PINNED_SENTINEL:
|
||||||
|
raise RuntimeError("All items are pinned, "
|
||||||
|
"cannot remove oldest from the cache.")
|
||||||
|
else:
|
||||||
|
lru_key = next(iter(self.cache))
|
||||||
|
self.pop(lru_key)
|
||||||
|
|
||||||
def _remove_old_if_needed(self) -> None:
|
def _remove_old_if_needed(self) -> None:
|
||||||
while len(self.cache) > self.capacity:
|
while len(self.cache) > self.capacity:
|
||||||
@ -120,13 +150,16 @@ class LRUCache(Generic[T]):
|
|||||||
default_value: Optional[T] = None) -> Optional[T]:
|
default_value: Optional[T] = None) -> Optional[T]:
|
||||||
run_on_remove = key in self.cache
|
run_on_remove = key in self.cache
|
||||||
value: Optional[T] = self.cache.pop(key, default_value)
|
value: Optional[T] = self.cache.pop(key, default_value)
|
||||||
|
# remove from pinned items
|
||||||
|
if key in self.pinned_items:
|
||||||
|
self._unpin(key)
|
||||||
if run_on_remove:
|
if run_on_remove:
|
||||||
self._on_remove(key, value)
|
self._on_remove(key, value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
while len(self.cache) > 0:
|
while len(self.cache) > 0:
|
||||||
self.remove_oldest()
|
self.remove_oldest(remove_pinned=True)
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -878,6 +878,11 @@ class ModelRunner:
|
|||||||
raise RuntimeError("LoRA is not enabled.")
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
return self.lora_manager.remove_lora(lora_id)
|
return self.lora_manager.remove_lora(lora_id)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.pin_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
if not self.lora_manager:
|
if not self.lora_manager:
|
||||||
raise RuntimeError("LoRA is not enabled.")
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
|||||||
@ -333,6 +333,9 @@ class Worker(WorkerBase):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
return self.model_runner.remove_lora(lora_id)
|
return self.model_runner.remove_lora(lora_id)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
return self.model_runner.pin_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.model_runner.list_loras()
|
return self.model_runner.list_loras()
|
||||||
|
|
||||||
|
|||||||
@ -70,6 +70,10 @@ class WorkerBase(ABC):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -86,6 +90,10 @@ class LoraNotSupportedWorkerBase(WorkerBase):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise ValueError(f"{type(self)} does not support LoRA")
|
raise ValueError(f"{type(self)} does not support LoRA")
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
return ValueError(
|
||||||
|
f"{type(self)} does not support LoRA") # type: ignore
|
||||||
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
raise ValueError(f"{type(self)} does not support LoRA")
|
raise ValueError(f"{type(self)} does not support LoRA")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user