[LoRA] Add support for pinning lora adapters in the LRU cache (#5603)

rohithkrn 2024-06-21 15:42:46 -07:00 committed by GitHub
parent 7187507301
commit f5dda63eb5
13 changed files with 171 additions and 5 deletions

View File

@@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     assert manager.activate_lora(3)
     assert manager.lora_index_to_id[0] == 2
     assert manager.lora_index_to_id[1] == 3
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 3
+    assert manager.activate_lora(1)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.deactivate_lora(2)
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.activate_lora(3)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.pin_lora(3)
+    assert manager.pin_lora(1)
+    with pytest.raises(RuntimeError):
+        assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    with pytest.raises(RuntimeError):
+        assert manager.activate_lora(2)
+
+    assert manager.deactivate_lora(3)
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.remove_lora(3)
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(3)


 def test_lru_lora_model_manager(dist_init, dummy_model):
@@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
     assert set(manager.list_loras()) == set()
     assert all(x is None for x in manager.lora_index_to_id)
+
+    # pinning
+    assert manager.add_lora(model_lora3)
+    assert manager.activate_lora(3)
+    assert manager.add_lora(model_lora4)
+    assert manager.activate_lora(4)
+    assert set(manager.list_loras()) == {3, 4}
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(1)
+    assert manager.pin_lora(3)
+    # Remove manually
+    assert manager.remove_lora(3)
+    assert not manager.remove_lora(3)
+
+    assert set(manager.list_loras()) == {4}
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 4
+
+    assert manager.add_lora(model_lora1)
+    assert manager.pin_lora(1)
+    assert manager.add_lora(model_lora2)
+    assert manager.activate_lora(2)
+
+    assert set(manager.list_loras()) == {1, 2}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] == 2
+
+    assert manager.remove_oldest_lora()
+    assert set(manager.list_loras()) == {1}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] is None
+
+    with pytest.raises(RuntimeError):
+        assert manager.remove_oldest_lora()
+
+    assert set(manager.list_loras()) == {1}


 def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                        sql_lora_files):

View File

@@ -1009,6 +1009,9 @@ class LLMEngine:
     def list_loras(self) -> Set[int]:
         return self.model_executor.list_loras()

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_executor.pin_lora(lora_id)
+
     def check_health(self) -> None:
         self.model_executor.check_health()
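
For orientation, a minimal usage sketch of the new engine-level hook (illustrative only: the model name, adapter path, and adapter ID below are placeholders, running it needs real model and adapter weights on disk, and only pin_lora itself comes from this commit):

# Hypothetical usage sketch, not part of the diff above.
from vllm import EngineArgs, LLMEngine
from vllm.lora.request import LoRARequest

engine = LLMEngine.from_engine_args(
    EngineArgs(model="meta-llama/Llama-2-7b-hf",
               enable_lora=True,
               max_loras=2,
               max_cpu_loras=4))

# Register an adapter, then pin it so neither the registered (CPU) nor the
# active (GPU) LRU cache will evict it.
engine.add_lora(LoRARequest("sql_adapter", 1, "/path/to/sql_lora"))
assert engine.pin_lora(1)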

View File

@@ -84,6 +84,9 @@ class CPUExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()

View File

@@ -100,6 +100,13 @@ class DistributedGPUExecutor(GPUExecutor):
             lora_id=lora_id,
         )

+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "pin_lora",
+            lora_id=lora_id,
+        )
+
     def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")

View File

@@ -86,6 +86,10 @@ class ExecutorBase(ABC):
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

+    @abstractmethod
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError  # type: ignore
+
     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError

View File

@@ -99,6 +99,10 @@ class GPUExecutor(ExecutorBase):
         assert lora_id > 0, "lora_id must be greater than 0."
         return self.driver_worker.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()

View File

@@ -65,6 +65,9 @@ class NeuronExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()

View File

@@ -525,6 +525,12 @@ class LoRAModelManager:
             self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
         return bool(self._registered_loras.pop(lora_id, None))

+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        raise NotImplementedError(
+            "Pinning is not supported in LoRAModelManager. "
+            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore
+
     # TODO see if this can be vectorized
     def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
         (base_indices, sampler_indices, sampler_indices_padded,
@@ -777,6 +783,26 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
             return True
         return False

+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        self._pin_lora_in_cpu_cache(lora_id)
+        self._pin_lora_in_gpu_cache(lora_id)
+        return True
+
+    def _pin_lora_in_cpu_cache(self, lora_id: int):
+        try:
+            self._registered_loras.pin(lora_id)
+        except ValueError as err:
+            raise ValueError("Pinning failed. "
+                             f"LoRA {lora_id} is not registered.") from err
+
+    def _pin_lora_in_gpu_cache(self, lora_id: int):
+        if lora_id not in self._active_loras:
+            # move lora to gpu if not already active
+            self.activate_lora(lora_id)
+
+        self._active_loras.pin(lora_id)
+

 def create_lora_manager(
         model: nn.Module,
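
The manager splits pinning across its two LRU caches: the adapter is pinned in _registered_loras (CPU) first, then activated on demand and pinned in _active_loras (GPU). A rough standalone sketch of that flow, using plain LRUCache objects as stand-ins rather than vLLM's real manager internals (capacities and adapter IDs are made up):

# Standalone sketch only; `registered` and `active` stand in for
# _registered_loras and _active_loras.
from vllm.utils import LRUCache

registered = LRUCache(capacity=4)  # adapters loaded on the CPU side
active = LRUCache(capacity=2)      # adapters activated on the GPU side

for lora_id in (1, 2, 3):
    registered.put(lora_id, f"adapter-{lora_id}")

def pin(lora_id: int) -> bool:
    try:
        registered.pin(lora_id)            # pin in the CPU cache
    except ValueError as err:
        raise ValueError("Pinning failed. "
                         f"LoRA {lora_id} is not registered.") from err
    if lora_id not in active:              # "activate" on demand
        if len(active) >= active.capacity:
            active.remove_oldest()         # RuntimeError if every slot is pinned
        active.put(lora_id, registered.get(lora_id))
    active.pin(lora_id)                    # pin in the GPU cache
    return True

assert pin(1) and pin(2)  # both GPU slots now hold pinned adapters
try:
    pin(3)                # no unpinned GPU slot is left to evict
except RuntimeError:
    pass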

View File

@@ -221,6 +221,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
     def remove_lora(self, lora_id: int) -> bool:
         return self._lora_manager.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self._lora_manager.pin_lora(lora_id)
+
     def remove_all_loras(self):
         self._lora_manager.remove_all_loras()

View File

@@ -15,7 +15,7 @@ from collections import defaultdict
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
+                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                     Union)

 import numpy as np
@@ -44,6 +44,13 @@ K = TypeVar("K")
 T = TypeVar("T")


+class _Sentinel:
+    ...
+
+
+ALL_PINNED_SENTINEL = _Sentinel()
+
+
 class Device(enum.Enum):
     GPU = enum.auto()
     CPU = enum.auto()
@@ -67,6 +74,7 @@ class LRUCache(Generic[T]):
     def __init__(self, capacity: int):
         self.cache: OrderedDict[Hashable, T] = OrderedDict()
+        self.pinned_items: Set[Hashable] = set()
         self.capacity = capacity

     def __contains__(self, key: Hashable) -> bool:
@@ -102,14 +110,36 @@ class LRUCache(Generic[T]):
         self.cache.move_to_end(key)
         self._remove_old_if_needed()

+    def pin(self, key: Hashable) -> None:
+        """
+        Pins a key in the cache preventing it from being
+        evicted in the LRU order.
+        """
+        if key not in self.cache:
+            raise ValueError(f"Cannot pin key: {key} not in cache.")
+        self.pinned_items.add(key)
+
+    def _unpin(self, key: Hashable) -> None:
+        self.pinned_items.remove(key)
+
     def _on_remove(self, key: Hashable, value: Optional[T]):
         pass

-    def remove_oldest(self):
+    def remove_oldest(self, remove_pinned=False):
         if not self.cache:
             return
-        key, value = self.cache.popitem(last=False)
-        self._on_remove(key, value)
+
+        if not remove_pinned:
+            # pop the oldest item in the cache that is not pinned
+            lru_key = next(
+                (key for key in self.cache if key not in self.pinned_items),
+                ALL_PINNED_SENTINEL)
+            if lru_key is ALL_PINNED_SENTINEL:
+                raise RuntimeError("All items are pinned, "
+                                   "cannot remove oldest from the cache.")
+        else:
+            lru_key = next(iter(self.cache))
+        self.pop(lru_key)

     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
@@ -120,13 +150,16 @@ class LRUCache(Generic[T]):
             default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
         value: Optional[T] = self.cache.pop(key, default_value)
+        # remove from pinned items
+        if key in self.pinned_items:
+            self._unpin(key)
         if run_on_remove:
             self._on_remove(key, value)
         return value

     def clear(self):
         while len(self.cache) > 0:
-            self.remove_oldest()
+            self.remove_oldest(remove_pinned=True)
         self.cache.clear()
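
A short demo of the pinning contract these LRUCache changes provide (assuming a vLLM build that includes this commit is importable):

from vllm.utils import LRUCache

cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)

cache.pin("a")             # "a" is now exempt from LRU eviction
cache.put("c", 3)          # over capacity: evicts "b", the oldest unpinned key
assert "a" in cache and "b" not in cache and "c" in cache

cache.pin("c")
try:
    cache.remove_oldest()  # every remaining key is pinned
except RuntimeError:
    pass

cache.pop("a")             # pop() also unpins, so "a" really goes away
assert "a" not in cache

cache.clear()              # clear() evicts everything, pinned or not
assert len(cache) == 0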

View File

@@ -878,6 +878,11 @@ class ModelRunner:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")

View File

@@ -333,6 +333,9 @@ class Worker(WorkerBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.model_runner.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.model_runner.list_loras()

View File

@@ -70,6 +70,10 @@ class WorkerBase(ABC):
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

+    @abstractmethod
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
@@ -86,6 +90,10 @@ class LoraNotSupportedWorkerBase(WorkerBase):
     def remove_lora(self, lora_id: int) -> bool:
         raise ValueError(f"{type(self)} does not support LoRA")

+    def pin_lora(self, lora_id: int) -> bool:
+        raise ValueError(f"{type(self)} does not support LoRA")
+
     def list_loras(self) -> Set[int]:
         raise ValueError(f"{type(self)} does not support LoRA")