From 4ad521d8b51145a55c1be6b8e451f76423cc2d87 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 20 Mar 2024 00:36:09 -0700
Subject: [PATCH] [Core] Add generic typing to `LRUCache` (#3511)

---
 vllm/lora/models.py                          |  6 +++---
 .../tokenizer_group/base_tokenizer_group.py  | 19 +++++++++++++------
 .../tokenizer_group/tokenizer_group.py       |  6 ++----
 vllm/utils.py                                | 20 +++++++++++---------
 4 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 238da256..6fe07b69 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -4,7 +4,7 @@ import logging
 import math
 import os
 import re
-from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
 
 import safetensors.torch
 import torch
@@ -535,14 +535,14 @@ class LoRAModelManager:
             replacement_loras)
 
 
-class LoRALRUCache(LRUCache):
+class LoRALRUCache(LRUCache[LoRAModel]):
 
     def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                    None]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
index 99518a60..3cce96e0 100644
--- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
         pass
 
     @abstractmethod
-    def encode(self, prompt: str, request_id: Optional[str],
-               lora_request: Optional[LoRARequest]) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
-    async def encode_async(self, prompt: str, request_id: Optional[str],
-                           lora_request: Optional[LoRARequest]) -> List[int]:
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
 
     @abstractmethod
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
index 3af1334c..ec20d0fb 100644
--- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        if enable_lora:
-            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
-        else:
-            self.lora_tokenizers = None
+        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
+            capacity=max_num_seqs) if enable_lora else None
 
     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""
diff --git a/vllm/utils.py b/vllm/utils.py
index 7c73062e..8fa372b5 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -5,7 +5,7 @@ import subprocess
 import uuid
 import gc
 from platform import uname
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Generic
 from packaging.version import parse, Version
 
 import psutil
@@ -53,10 +53,10 @@ class Counter:
         self.counter = 0
 
 
-class LRUCache:
+class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
-        self.cache = OrderedDict()
+        self.cache = OrderedDict[Hashable, T]()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
@@ -65,10 +65,10 @@ class LRUCache:
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> Any:
+    def __getitem__(self, key: Hashable) -> T:
         return self.get(key)
 
-    def __setitem__(self, key: Hashable, value: Any) -> None:
+    def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)
 
     def __delitem__(self, key: Hashable) -> None:
@@ -77,7 +77,9 @@ class LRUCache:
     def touch(self, key: Hashable) -> None:
         self.cache.move_to_end(key)
 
-    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+    def get(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
@@ -85,12 +87,12 @@ class LRUCache:
             value = default_value
         return value
 
-    def put(self, key: Hashable, value: Any) -> None:
+    def put(self, key: Hashable, value: T) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: T):
         pass
 
     def remove_oldest(self):
@@ -103,7 +105,7 @@ class LRUCache:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
         run_on_remove = key in self.cache
         value = self.cache.pop(key, default_value)
         if run_on_remove:
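
Illustrative usage (not part of the patch): a minimal sketch of how the now-generic cache reads after this change, assuming `vllm.utils.LRUCache` exactly as modified above; the capacity and the string values are arbitrary placeholders, not taken from vLLM.

    from vllm.utils import LRUCache

    # Values are parameterized by T; keys remain Hashable as before.
    cache: LRUCache[str] = LRUCache(capacity=2)
    cache.put(1, "a")
    cache.put(2, "b")
    cache.put(3, "c")                # exceeds capacity, evicting key 1 (least recently used)
    assert cache.get(1) is None      # miss falls back to the Optional[T] default
    assert cache.get(3) == "c"       # hit is typed as str rather than Any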