[core][misc] remove logical block (#5882)
parent 79c92c7c8a
commit 64e8d2a783
vllm/block.py
@@ -1,90 +1,10 @@
 """Token blocks."""
-import weakref
-from collections import defaultdict
-from typing import Dict, List
+from typing import List

 from vllm.utils import Device

-_BLANK_TOKEN_ID = -1
-
 DEFAULT_LAST_ACCESSED_TIME = -1

-TokensBlock = List[int]
-
-
-class BlockPool:
-    """A pool of logical blocks.
-    When requests come, we create a lot of logical blocks;
-    when requests are done, we destroy a lot of logical blocks.
-    It turns out that creating and destroying logical blocks can be expensive,
-    especially for the `token_ids` field, which is a list of integers.
-    To avoid this overhead, we use a pool to manage the logical blocks.
-    When an old request is done and a new request comes, we can reuse the
-    logical blocks from the old request to feed the new request.
-    """
-
-    def __init__(self) -> None:
-        # block size to list of token blocks
-        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)
-
-    def alloc_block(self, block_size: int) -> TokensBlock:
-        if block_size in self.pool and self.pool[block_size]:
-            return self.pool[block_size].pop()
-        return [_BLANK_TOKEN_ID] * block_size
-
-    def del_block(self, block: TokensBlock) -> None:
-        self.pool[len(block)].append(block)
-
-
-_BLOCK_POOL = BlockPool()
-
-
-class LogicalTokenBlock:
-    """A block that stores a contiguous chunk of tokens from left to right.
-
-    Logical blocks are used to represent the states of the corresponding
-    physical blocks in the KV cache.
-    """
-
-    def __init__(
-        self,
-        block_number: int,
-        block_size: int,
-    ) -> None:
-        self.block_number = block_number
-        self.block_size = block_size
-
-        self.token_ids = _BLOCK_POOL.alloc_block(block_size)
-        # this finalizer is used to return the block to the pool when the object is deleted  # noqa
-        # NOTE: don't use __del__ because it cannot guarantee the order of finalization,  # noqa
-        # i.e. `self.token_ids` may be deleted before `self`, and we lose
-        # the opportunity to return the block to the pool
-        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
-                                           self.token_ids)
-        self.num_tokens = 0
-
-    def is_empty(self) -> bool:
-        return self.num_tokens == 0
-
-    def get_num_empty_slots(self) -> int:
-        return self.block_size - self.num_tokens
-
-    def is_full(self) -> bool:
-        return self.num_tokens == self.block_size
-
-    def append_tokens(self, token_ids: List[int]) -> None:
-        assert len(token_ids) <= self.get_num_empty_slots()
-        curr_idx = self.num_tokens
-        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
-        self.num_tokens += len(token_ids)
-
-    def get_token_ids(self) -> List[int]:
-        return self.token_ids[:self.num_tokens]
-
-    def get_last_token_id(self) -> int:
-        assert self.num_tokens > 0
-        return self.token_ids[self.num_tokens - 1]
-

 class PhysicalTokenBlock:
     """Represents the state of a block in the KV cache."""
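For reference, the pool-plus-finalizer pattern that the removed `BlockPool` relied on can be sketched on its own. This is an illustrative, self-contained example (names such as `BufferPool` and `Block` are made up here, not vLLM API); it shows why the removed code used `weakref.finalize` rather than `__del__` to return a block's `token_ids` list to the pool:

```python
import weakref
from collections import defaultdict
from typing import Dict, List


class BufferPool:
    """Recycles fixed-size lists so hot paths avoid repeated allocation."""

    def __init__(self) -> None:
        # buffer size -> free buffers of that size
        self._free: Dict[int, List[List[int]]] = defaultdict(list)

    def alloc(self, size: int) -> List[int]:
        if self._free[size]:
            return self._free[size].pop()
        return [-1] * size

    def release(self, buf: List[int]) -> None:
        self._free[len(buf)].append(buf)


_POOL = BufferPool()


class Block:

    def __init__(self, size: int) -> None:
        self.token_ids = _POOL.alloc(size)
        # weakref.finalize holds only a weak reference to `self`, so storing
        # the finalizer on the instance does not keep it alive; when the block
        # is collected, the callback still receives the live token_ids list
        # and returns it to the pool. __del__ cannot guarantee that ordering.
        self._finalizer = weakref.finalize(self, _POOL.release, self.token_ids)


block = Block(16)
del block                           # CPython refcounting runs the finalizer here
assert len(_POOL._free[16]) == 1    # the buffer was recycled, not discarded
```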
vllm/core/block_manager_v1.py
@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         self.cross_block_tables: Dict[str, BlockTable] = {}

     def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
-        return 0 if seq is None \
-            else len(seq.logical_token_blocks)
+        return 0 if seq is None else seq.n_blocks

     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
@@ -298,7 +297,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
                            ref_count: int, \
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
-        num_prompt_blocks = len(seq.logical_token_blocks)
+        num_prompt_blocks = seq.n_blocks

         block_table: BlockTable = []
         for logical_idx in range(num_prompt_blocks):
@@ -367,7 +366,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):

         # Compute a new hash for the block so that it can be shared by other
         # Sequences
-        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        new_hash = seq.hash_of_block(seq.n_blocks - 1)

         # if new_hash is already in the cached table, then free last_block
         # and return the cached version
@@ -407,10 +406,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         if not self.enable_caching:
             return self.gpu_allocator.allocate()
         block_hash: Optional[int] = None
+        n_blocks = seq.n_blocks
         if (self._is_last_block_full(seq)):
-            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
-        num_hashed_tokens = seq.num_hashed_tokens_of_block(
-            len(seq.logical_token_blocks) - 1)
+            block_hash = seq.hash_of_block(n_blocks - 1)
+        num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)

         # num_hashed_tokens is used to compute future hashes
         # (e.g. in the hashing function, it is used to ask the sequence for
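As context for the prefix-caching calls above, `hash_of_block(i)` identifies a full logical block by the tokens from the start of the sequence through the end of block `i`, so sequences that share a prefix produce the same hash for the shared blocks. A simplified sketch of that idea (not the exact vLLM hashing scheme, which may fold in additional per-sequence state):

```python
from typing import List


def hash_of_block(token_ids: List[int], block_size: int,
                  logical_idx: int) -> int:
    # A full block is identified by every token from the start of the
    # sequence through the end of that block, so two sequences that share
    # a prefix map the shared blocks to the same hash value.
    num_hashed_tokens = block_size * (logical_idx + 1)
    return hash(tuple(token_ids[:num_hashed_tokens]))


block_size = 4
seq_a = [1, 2, 3, 4, 5, 6, 7, 8]
seq_b = [1, 2, 3, 4, 9, 9, 9, 9]
# The first block (a shared prefix) hashes identically...
assert hash_of_block(seq_a, block_size, 0) == hash_of_block(seq_b, block_size, 0)
# ...while the second block, where the sequences diverge, does not.
assert hash_of_block(seq_a, block_size, 1) != hash_of_block(seq_b, block_size, 1)
```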
@@ -429,12 +428,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         num_lookahead_slots: int = 0,
     ) -> List[Tuple[int, int]]:
         """Allocate a physical slot for a new token."""
-        logical_blocks = seq.logical_token_blocks
+        n_blocks = seq.n_blocks
         block_table = self.block_tables[seq.seq_id]
         # If we need to allocate a new physical block
-        if len(block_table) < len(logical_blocks):
+        if len(block_table) < n_blocks:
             # Currently this code only supports adding one physical block
-            assert len(block_table) == len(logical_blocks) - 1
+            assert len(block_table) == n_blocks - 1

             if (self.block_sliding_window
                     and len(block_table) >= self.block_sliding_window):
vllm/sequence.py
@@ -1,13 +1,13 @@
 """Sequence and its related classes."""
 import copy
 import enum
+import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

-from vllm.block import LogicalTokenBlock
 from vllm.inputs import LLMInputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -236,9 +236,6 @@ class Sequence:
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

-        self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens_to_blocks(self.prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None

@@ -248,6 +245,10 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None

+    @property
+    def n_blocks(self) -> int:
+        return math.ceil(self.get_len() / self.block_size)
+
     @property
     def prompt(self) -> Optional[str]:
         return self.inputs.get("prompt")
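The new `n_blocks` property recomputes the block count from the sequence length instead of keeping a list of `LogicalTokenBlock` objects in sync. A small standalone sketch (hypothetical helper, not the `Sequence` class itself) of why the ceiling division yields the same count the old eagerly maintained list would have had:

```python
import math


def n_blocks(num_tokens: int, block_size: int) -> int:
    # One block per block_size tokens, plus one partially filled block for
    # any remainder -- exactly how many logical blocks the removed code
    # would have appended for the same number of tokens.
    return math.ceil(num_tokens / block_size)


block_size = 16
assert n_blocks(0, block_size) == 0    # no tokens, no blocks
assert n_blocks(16, block_size) == 1   # exactly one full block
assert n_blocks(17, block_size) == 2   # one full block plus one slot used
```

With the property in place, the block manager never needs per-sequence block objects; it only asks how many blocks the sequence currently occupies.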
@@ -287,36 +288,12 @@ class Sequence:
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()

-    def _append_logical_block(self) -> None:
-        block = LogicalTokenBlock(
-            block_number=len(self.logical_token_blocks),
-            block_size=self.block_size,
-        )
-        self.logical_token_blocks.append(block)
-
-    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
-        cursor = 0
-        while cursor < len(token_ids):
-            if not self.logical_token_blocks:
-                self._append_logical_block()
-
-            last_block = self.logical_token_blocks[-1]
-            if last_block.is_full():
-                self._append_logical_block()
-                last_block = self.logical_token_blocks[-1]
-
-            num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append_tokens(token_ids[cursor:cursor +
-                                               num_empty_slots])
-            cursor += num_empty_slots
-
     def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
     ) -> None:
         assert token_id in logprobs
-        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
         self.data.append_token_id(token_id, logprobs[token_id].logprob)

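The removed `_append_tokens_to_blocks` walked a cursor over the incoming token ids, topping up the last partially filled block before opening new fixed-size blocks. An equivalent standalone sketch of that chunking loop (illustrative function, not part of vLLM):

```python
from typing import List


def append_tokens_to_blocks(blocks: List[List[int]], token_ids: List[int],
                            block_size: int) -> None:
    """Append token_ids to blocks, filling the last block before opening new ones."""
    cursor = 0
    while cursor < len(token_ids):
        if not blocks or len(blocks[-1]) == block_size:
            blocks.append([])  # open a fresh block
        last_block = blocks[-1]
        num_empty_slots = block_size - len(last_block)
        last_block.extend(token_ids[cursor:cursor + num_empty_slots])
        cursor += num_empty_slots


blocks: List[List[int]] = []
append_tokens_to_blocks(blocks, list(range(10)), block_size=4)
assert blocks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```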
@@ -388,7 +365,7 @@ class Sequence:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={len(self.logical_token_blocks)})")
+                f"num_blocks={self.n_blocks}, ")


 @dataclass