[core][misc] remove logical block (#5882)

This commit is contained in:
youkaichao 2024-06-27 13:34:55 -07:00 committed by GitHub
parent 79c92c7c8a
commit 64e8d2a783
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 120 deletions

View File

@ -1,90 +1,10 @@
"""Token blocks.""" """Token blocks."""
import weakref from typing import List
from collections import defaultdict
from typing import Dict, List
from vllm.utils import Device from vllm.utils import Device
# Filler value written into every slot of a newly created token buffer.
_BLANK_TOKEN_ID = -1
DEFAULT_LAST_ACCESSED_TIME = -1 DEFAULT_LAST_ACCESSED_TIME = -1
# Alias for one block's token-id storage: a plain list of ints.
TokensBlock = List[int]
class BlockPool:
    """A free-list of token-id buffers, bucketed by block size.

    When requests come, many logical blocks are created; when requests
    finish, many are destroyed. Building the ``token_ids`` list for each
    block is comparatively expensive, so retired lists are parked here
    and handed back out to later blocks of the same size.
    """

    def __init__(self) -> None:
        # Maps a block size to the retired buffers of that length.
        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)

    def alloc_block(self, block_size: int) -> "TokensBlock":
        """Return a buffer with ``block_size`` slots.

        A retired buffer is reused when one is available (most recently
        returned first). It is handed back as-is, so it may still hold
        the previous owner's token ids. Otherwise a new list filled with
        ``_BLANK_TOKEN_ID`` is created.
        """
        # ``.get`` does a single lookup and never inserts an empty bucket
        # into the defaultdict.
        recycled = self.pool.get(block_size)
        if recycled:
            return recycled.pop()
        return [_BLANK_TOKEN_ID] * block_size

    def del_block(self, block: "TokensBlock") -> None:
        """Park ``block`` for later reuse, bucketed by its length."""
        self.pool[len(block)].append(block)
# Module-level pool instance; LogicalTokenBlock draws its token buffers from
# it and registers a finalizer that returns them to it.
_BLOCK_POOL = BlockPool()
class LogicalTokenBlock:
    """A block that stores a contiguous chunk of tokens from left to right.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size
        # Storage is borrowed from the module-level pool, not built here.
        self.token_ids = _BLOCK_POOL.alloc_block(block_size)
        # Return the buffer to the pool once this object is collected.
        # NOTE: a finalizer is used instead of __del__ because __del__
        # cannot guarantee the order of finalization: `self.token_ids`
        # could be gone before `self`, losing the chance to recycle it.
        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
                                           self.token_ids)
        # Count of slots filled so far; slots past it are not meaningful.
        self.num_tokens = 0

    def is_empty(self) -> bool:
        """Return True when no token has been stored yet."""
        return self.num_tokens == 0

    def get_num_empty_slots(self) -> int:
        """Return how many more tokens fit into this block."""
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        """Return True when every slot holds a token."""
        return self.num_tokens == self.block_size

    def append_tokens(self, token_ids: List[int]) -> None:
        """Store ``token_ids`` in the next free slots.

        The caller must not pass more tokens than there are empty slots;
        this is checked with an assertion.
        """
        assert len(token_ids) <= self.get_num_empty_slots()
        curr_idx = self.num_tokens
        # Equal-length slice assignment overwrites slots in place, so the
        # buffer keeps its pooled length.
        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a copy of the tokens stored so far."""
        return self.token_ids[:self.num_tokens]

    def get_last_token_id(self) -> int:
        """Return the most recently stored token; the block must be non-empty."""
        assert self.num_tokens > 0
        return self.token_ids[self.num_tokens - 1]
class PhysicalTokenBlock: class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache.""" """Represents the state of a block in the KV cache."""

View File

@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.cross_block_tables: Dict[str, BlockTable] = {} self.cross_block_tables: Dict[str, BlockTable] = {}
# Number of blocks `seq` needs; 0 when `seq` is None. The first signature/
# return text on each row is the earlier form (it counts
# seq.logical_token_blocks and continues onto the following row); the second
# is the later form, which reads seq.n_blocks.
def _get_seq_num_required_blocks(self, seq: Sequence) -> int: def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
return 0 if seq is None \ return 0 if seq is None else seq.n_blocks
else len(seq.logical_token_blocks)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share # FIXME(woosuk): Here we assume that all sequences in the group share
@ -298,7 +297,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
ref_count: int, \ ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable: is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks) num_prompt_blocks = seq.n_blocks
block_table: BlockTable = [] block_table: BlockTable = []
for logical_idx in range(num_prompt_blocks): for logical_idx in range(num_prompt_blocks):
@ -367,7 +366,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Compute a new hash for the block so that it can be shared by other # Compute a new hash for the block so that it can be shared by other
# Sequences # Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) new_hash = seq.hash_of_block(seq.n_blocks - 1)
# if new_hash is already in the cached table, then free last_block # if new_hash is already in the cached table, then free last_block
# and return the cached version # and return the cached version
@ -407,10 +406,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if not self.enable_caching: if not self.enable_caching:
return self.gpu_allocator.allocate() return self.gpu_allocator.allocate()
block_hash: Optional[int] = None block_hash: Optional[int] = None
n_blocks = seq.n_blocks
if (self._is_last_block_full(seq)): if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) block_hash = seq.hash_of_block(n_blocks - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block( num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)
len(seq.logical_token_blocks) - 1)
# num_hashed_tokens is used to compute future hashes # num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for # (e.g. in the hashing function, it is used to ask the sequence for
@ -429,12 +428,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_lookahead_slots: int = 0, num_lookahead_slots: int = 0,
) -> List[Tuple[int, int]]: ) -> List[Tuple[int, int]]:
"""Allocate a physical slot for a new token.""" """Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks n_blocks = seq.n_blocks
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block # If we need to allocate a new physical block
if len(block_table) < len(logical_blocks): if len(block_table) < n_blocks:
# Currently this code only supports adding one physical block # Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1 assert len(block_table) == n_blocks - 1
if (self.block_sliding_window if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window): and len(block_table) >= self.block_sliding_window):

View File

@ -1,13 +1,13 @@
"""Sequence and its related classes.""" """Sequence and its related classes."""
import copy import copy
import enum import enum
import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch import torch
from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
@ -236,9 +236,6 @@ class Sequence:
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
self.output_text = "" self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
@ -248,6 +245,10 @@ class Sequence:
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[List[str]] = None
@property
def n_blocks(self) -> int:
    """Number of blocks needed to hold ``self.get_len()`` tokens.

    Equal to ``ceil(self.get_len() / self.block_size)``, but computed
    with integer arithmetic so the result is exact for any length and
    never passes through a float.
    """
    # Negated floor division is ceiling division for integers.
    return -(-self.get_len() // self.block_size)
# Read-only accessor for the "prompt" entry of self.inputs, looked up with
# .get() (presumably yielding None when the key is absent -- confirm against
# the type of self.inputs).
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
return self.inputs.get("prompt") return self.inputs.get("prompt")
@ -287,36 +288,12 @@ class Sequence:
"""Reset the sequence states for recomputation.""" """Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute() self.data.reset_state_for_recompute()
def _append_logical_block(self) -> None:
    """Append a new logical block to ``self.logical_token_blocks``.

    The block is numbered with its index in the list and sized with
    ``self.block_size``.
    """
    next_index = len(self.logical_token_blocks)
    new_block = LogicalTokenBlock(
        block_number=next_index,
        block_size=self.block_size,
    )
    self.logical_token_blocks.append(new_block)
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
    """Distribute ``token_ids`` over the logical blocks, in order.

    The last block is topped up first. A new block is appended whenever
    there is no block yet or the current last block is full.
    """
    total = len(token_ids)
    cursor = 0
    while cursor < total:
        if not self.logical_token_blocks:
            self._append_logical_block()
        if self.logical_token_blocks[-1].is_full():
            self._append_logical_block()
        tail = self.logical_token_blocks[-1]
        room = tail.get_num_empty_slots()
        # The slice clamps at the end of token_ids, so the final chunk may
        # be shorter than `room`.
        tail.append_tokens(token_ids[cursor:cursor + room])
        # Advancing by `room` can overshoot `total`; that only ends the loop.
        cursor += room
# Record one generated token: token_id must be a key of logprobs (asserted);
# the logprobs dict is appended to self.output_logprobs and the token id is
# forwarded to self.data together with its own logprob value. The single
# unpaired row below (_append_tokens_to_blocks) belongs to the earlier form
# only, which also pushed the token into the logical blocks.
def append_token_id( def append_token_id(
self, self,
token_id: int, token_id: int,
logprobs: Dict[int, Logprob], logprobs: Dict[int, Logprob],
) -> None: ) -> None:
assert token_id in logprobs assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs) self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob)
@ -388,7 +365,7 @@ class Sequence:
def __repr__(self) -> str:
    """Return a debug string with the sequence id, status name and block count."""
    # The closing ")" must sit inside the last fragment so the text is a
    # balanced "Sequence(...)" with no dangling ", " separator.
    return (f"Sequence(seq_id={self.seq_id}, "
            f"status={self.status.name}, "
            f"num_blocks={self.n_blocks})")
@dataclass @dataclass