[Core] Add generic typing to LRUCache (#3511)

Nick Hill authored 2024-03-20 00:36:09 -07:00, committed by GitHub
parent 9474e89ba4
commit 4ad521d8b5
signed with GPG Key ID B5690EEEBB952194 (no known key found for this signature in the database)

4 changed files with 29 additions and 22 deletions
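
The change parameterizes LRUCache with the type of its values, so callers get concrete types back instead of Any. A minimal sketch of what that buys at a call site, assuming LRUCache is importable from vllm.utils as the last diff below suggests:

    from vllm.utils import LRUCache

    # Subscripting the Generic class is purely for the type checker;
    # the instance behaves exactly as before.
    cache = LRUCache[str](capacity=2)
    cache.put(1, "a")
    cache.put(2, "b")
    cache.put(3, "c")    # capacity exceeded: evicts key 1 (least recently used)
    print(cache.get(1))  # None -- and inferred as Optional[str], not Any
    print(cache.get(3))  # "c"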


@@ -4,7 +4,7 @@ import logging
 import math
 import os
 import re
-from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)

 import safetensors.torch
 import torch

@@ -535,14 +535,14 @@ class LoRAModelManager:
                               replacement_loras)


-class LoRALRUCache(LRUCache):
+class LoRALRUCache(LRUCache[LoRAModel]):

     def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                    None]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn

-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
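
Binding the type parameter at subclass time is what lets the _on_remove override take a concrete value type. A self-contained sketch of the same pattern with hypothetical names (EvictingCache, on_evict), again assuming the vllm.utils import path:

    from typing import Callable, Hashable

    from vllm.utils import LRUCache


    class EvictingCache(LRUCache[bytes]):
        """T is fixed to bytes, so the eviction hook below is fully typed."""

        def __init__(self, capacity: int,
                     on_evict: Callable[[Hashable], None]):
            super().__init__(capacity)
            self.on_evict = on_evict

        def _on_remove(self, key: Hashable, value: bytes):
            self.on_evict(key)                  # notify before discarding
            return super()._on_remove(key, value)


    cache = EvictingCache(1, lambda k: print("evicted:", k))
    cache.put("x", b"...")
    cache.put("y", b"...")                      # prints "evicted: x"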


@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
         pass

     @abstractmethod
-    def encode(self, prompt: str, request_id: Optional[str],
-               lora_request: Optional[LoRARequest]) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass

     @abstractmethod
-    async def encode_async(self, prompt: str, request_id: Optional[str],
-                           lora_request: Optional[LoRARequest]) -> List[int]:
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass

     @abstractmethod
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass

     @abstractmethod
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
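
Adding defaults in the abstract signatures makes the optional arguments optional for callers, but note a Python subtlety that concrete subclasses must respect: overrides do not inherit defaults from the abstract method and have to repeat them. A toy example, unrelated to vLLM's classes:

    from abc import ABC, abstractmethod
    from typing import Optional


    class Greeter(ABC):

        @abstractmethod
        def greet(self, name: str, prefix: Optional[str] = None) -> str:
            ...


    class PlainGreeter(Greeter):

        # The default must be repeated here; Python resolves defaults on the
        # method actually called, not on the abstract declaration.
        def greet(self, name: str, prefix: Optional[str] = None) -> str:
            return f"{prefix or 'Hello'}, {name}!"


    print(PlainGreeter().greet("world"))   # Hello, world!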


@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        if enable_lora:
-            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
-        else:
-            self.lora_tokenizers = None
+        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
+            capacity=max_num_seqs) if enable_lora else None

     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""


@@ -5,7 +5,7 @@ import subprocess
 import uuid
 import gc
 from platform import uname
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Generic
 from packaging.version import parse, Version

 import psutil

@@ -53,10 +53,10 @@ class Counter:
         self.counter = 0


-class LRUCache:
+class LRUCache(Generic[T]):

     def __init__(self, capacity: int):
-        self.cache = OrderedDict()
+        self.cache = OrderedDict[Hashable, T]()
         self.capacity = capacity

     def __contains__(self, key: Hashable) -> bool:

@@ -65,10 +65,10 @@ class LRUCache:
     def __len__(self) -> int:
         return len(self.cache)

-    def __getitem__(self, key: Hashable) -> Any:
+    def __getitem__(self, key: Hashable) -> T:
         return self.get(key)

-    def __setitem__(self, key: Hashable, value: Any) -> None:
+    def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)

     def __delitem__(self, key: Hashable) -> None:

@@ -77,7 +77,9 @@ class LRUCache:
     def touch(self, key: Hashable) -> None:
         self.cache.move_to_end(key)

-    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+    def get(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)

@@ -85,12 +87,12 @@ class LRUCache:
             value = default_value
         return value

-    def put(self, key: Hashable, value: Any) -> None:
+    def put(self, key: Hashable, value: T) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()

-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: T):
         pass

     def remove_oldest(self):

@@ -103,7 +105,7 @@ class LRUCache:
         while len(self.cache) > self.capacity:
             self.remove_oldest()

-    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
         run_on_remove = key in self.cache
         value = self.cache.pop(key, default_value)
         if run_on_remove:
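
Putting the pieces together, here is a condensed, self-contained version of the resulting data structure (hypothetical name MiniLRUCache; the real class also has __contains__, __len__, touch, pop, and friends). One caveat: the subscripted runtime form OrderedDict[Hashable, T]() used in the diff requires Python 3.9+; the sketch uses a plain OrderedDict() with a string annotation, which also works on 3.8:

    from collections import OrderedDict
    from typing import Generic, Hashable, Optional, TypeVar

    T = TypeVar("T")


    class MiniLRUCache(Generic[T]):

        def __init__(self, capacity: int):
            self.cache: "OrderedDict[Hashable, T]" = OrderedDict()
            self.capacity = capacity

        def get(self, key: Hashable, default: Optional[T] = None) -> Optional[T]:
            if key in self.cache:
                self.cache.move_to_end(key)      # mark as most recently used
                return self.cache[key]
            return default

        def put(self, key: Hashable, value: T) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            while len(self.cache) > self.capacity:
                # popitem(last=False) pops the least recently used entry
                self._on_remove(*self.cache.popitem(last=False))

        def _on_remove(self, key: Hashable, value: T) -> None:
            pass                                 # subclass hook, as above


    c = MiniLRUCache[int](capacity=2)
    c.put("a", 1)
    c.put("b", 2)
    c.put("c", 3)                                # evicts "a"
    print(c.get("a"), c.get("c"))                # None 3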