[Core] Add generic typing to LRUCache (#3511)
This commit is contained in:
parent
9474e89ba4
commit
4ad521d8b5
@ -4,7 +4,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
|
from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
|
||||||
|
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
@ -535,14 +535,14 @@ class LoRAModelManager:
|
|||||||
replacement_loras)
|
replacement_loras)
|
||||||
|
|
||||||
|
|
||||||
class LoRALRUCache(LRUCache):
|
class LoRALRUCache(LRUCache[LoRAModel]):
|
||||||
|
|
||||||
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
|
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
|
||||||
None]):
|
None]):
|
||||||
super().__init__(capacity)
|
super().__init__(capacity)
|
||||||
self.deactivate_lora_fn = deactivate_lora_fn
|
self.deactivate_lora_fn = deactivate_lora_fn
|
||||||
|
|
||||||
def _on_remove(self, key: Hashable, value: Any):
|
def _on_remove(self, key: Hashable, value: LoRAModel):
|
||||||
logger.debug(f"Removing LoRA. int id: {key}")
|
logger.debug(f"Removing LoRA. int id: {key}")
|
||||||
self.deactivate_lora_fn(key)
|
self.deactivate_lora_fn(key)
|
||||||
return super()._on_remove(key, value)
|
return super()._on_remove(key, value)
|
||||||
|
|||||||
@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def encode(self, prompt: str, request_id: Optional[str],
|
def encode(self,
|
||||||
lora_request: Optional[LoRARequest]) -> List[int]:
|
prompt: str,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||||
"""Encode a prompt using the tokenizer group."""
|
"""Encode a prompt using the tokenizer group."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def encode_async(self, prompt: str, request_id: Optional[str],
|
async def encode_async(
|
||||||
lora_request: Optional[LoRARequest]) -> List[int]:
|
self,
|
||||||
|
prompt: str,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||||
"""Encode a prompt using the tokenizer group."""
|
"""Encode a prompt using the tokenizer group."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_lora_tokenizer(
|
def get_lora_tokenizer(
|
||||||
self,
|
self,
|
||||||
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
|
lora_request: Optional[LoRARequest] = None
|
||||||
|
) -> "PreTrainedTokenizer":
|
||||||
"""Get a tokenizer for a LoRA request."""
|
"""Get a tokenizer for a LoRA request."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_lora_tokenizer_async(
|
async def get_lora_tokenizer_async(
|
||||||
self,
|
self,
|
||||||
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
|
lora_request: Optional[LoRARequest] = None
|
||||||
|
) -> "PreTrainedTokenizer":
|
||||||
"""Get a tokenizer for a LoRA request."""
|
"""Get a tokenizer for a LoRA request."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
|
|||||||
self.enable_lora = enable_lora
|
self.enable_lora = enable_lora
|
||||||
self.max_input_length = max_input_length
|
self.max_input_length = max_input_length
|
||||||
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
|
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
|
||||||
if enable_lora:
|
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
|
||||||
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
|
capacity=max_num_seqs) if enable_lora else None
|
||||||
else:
|
|
||||||
self.lora_tokenizers = None
|
|
||||||
|
|
||||||
def ping(self) -> bool:
|
def ping(self) -> bool:
|
||||||
"""Check if the tokenizer group is alive."""
|
"""Check if the tokenizer group is alive."""
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import subprocess
|
|||||||
import uuid
|
import uuid
|
||||||
import gc
|
import gc
|
||||||
from platform import uname
|
from platform import uname
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union, Generic
|
||||||
from packaging.version import parse, Version
|
from packaging.version import parse, Version
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
@ -53,10 +53,10 @@ class Counter:
|
|||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
|
||||||
|
|
||||||
class LRUCache:
|
class LRUCache(Generic[T]):
|
||||||
|
|
||||||
def __init__(self, capacity: int):
|
def __init__(self, capacity: int):
|
||||||
self.cache = OrderedDict()
|
self.cache = OrderedDict[Hashable, T]()
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
|
|
||||||
def __contains__(self, key: Hashable) -> bool:
|
def __contains__(self, key: Hashable) -> bool:
|
||||||
@ -65,10 +65,10 @@ class LRUCache:
|
|||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.cache)
|
return len(self.cache)
|
||||||
|
|
||||||
def __getitem__(self, key: Hashable) -> Any:
|
def __getitem__(self, key: Hashable) -> T:
|
||||||
return self.get(key)
|
return self.get(key)
|
||||||
|
|
||||||
def __setitem__(self, key: Hashable, value: Any) -> None:
|
def __setitem__(self, key: Hashable, value: T) -> None:
|
||||||
self.put(key, value)
|
self.put(key, value)
|
||||||
|
|
||||||
def __delitem__(self, key: Hashable) -> None:
|
def __delitem__(self, key: Hashable) -> None:
|
||||||
@ -77,7 +77,9 @@ class LRUCache:
|
|||||||
def touch(self, key: Hashable) -> None:
|
def touch(self, key: Hashable) -> None:
|
||||||
self.cache.move_to_end(key)
|
self.cache.move_to_end(key)
|
||||||
|
|
||||||
def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
|
def get(self,
|
||||||
|
key: Hashable,
|
||||||
|
default_value: Optional[T] = None) -> Optional[T]:
|
||||||
if key in self.cache:
|
if key in self.cache:
|
||||||
value = self.cache[key]
|
value = self.cache[key]
|
||||||
self.cache.move_to_end(key)
|
self.cache.move_to_end(key)
|
||||||
@ -85,12 +87,12 @@ class LRUCache:
|
|||||||
value = default_value
|
value = default_value
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def put(self, key: Hashable, value: Any) -> None:
|
def put(self, key: Hashable, value: T) -> None:
|
||||||
self.cache[key] = value
|
self.cache[key] = value
|
||||||
self.cache.move_to_end(key)
|
self.cache.move_to_end(key)
|
||||||
self._remove_old_if_needed()
|
self._remove_old_if_needed()
|
||||||
|
|
||||||
def _on_remove(self, key: Hashable, value: Any):
|
def _on_remove(self, key: Hashable, value: T):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def remove_oldest(self):
|
def remove_oldest(self):
|
||||||
@ -103,7 +105,7 @@ class LRUCache:
|
|||||||
while len(self.cache) > self.capacity:
|
while len(self.cache) > self.capacity:
|
||||||
self.remove_oldest()
|
self.remove_oldest()
|
||||||
|
|
||||||
def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
|
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
|
||||||
run_on_remove = key in self.cache
|
run_on_remove = key in self.cache
|
||||||
value = self.cache.pop(key, default_value)
|
value = self.cache.pop(key, default_value)
|
||||||
if run_on_remove:
|
if run_on_remove:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user