[Core] Sliding window for block manager v2 (#4545)
Co-authored-by: Ruth Evans <ruthevans@Ruths-MacBook-Pro.local>
parent 890aa93d27
commit d4f3985907
@@ -1,3 +1,5 @@
from typing import Callable, Iterable, Optional

import pytest

from vllm import LLM
@@ -40,3 +42,27 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
    for llm in generator_inner():
        yield llm
        del llm


def get_text_from_llm_generator(llm_generator: Iterable[LLM],
                                prompts,
                                sampling_params,
                                llm_cb: Optional[Callable[[LLM],
                                                          None]] = None):
    for llm in llm_generator:
        if llm_cb:
            llm_cb(llm)
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        text = [output.outputs[0].text for output in outputs]
        del llm

    return text


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids
@@ -4,6 +4,8 @@ import pytest

from vllm import SamplingParams

from .conftest import get_token_ids_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
@@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids
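
The helper now lives in the shared conftest, so individual test modules import it instead of keeping local copies. A minimal sketch of that usage (the llm_generator, prompts, and sampling_params names come from the test fixtures above):

from .conftest import get_token_ids_from_llm_generator

token_ids = get_token_ids_from_llm_generator(llm_generator, prompts,
                                             sampling_params)
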
tests/core/block/e2e/test_correctness_sliding_window.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import random
from typing import List

import pytest

from vllm import LLM, SamplingParams

from .conftest import get_text_from_llm_generator

# relatively small model with 4k sliding window
MODEL = "bigcode/starcoder2-3b"
BLOCK_SIZE = 16


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": MODEL,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,
        "block_size": BLOCK_SIZE,
        # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008
        "num_gpu_blocks_override": 100000 // BLOCK_SIZE,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
                                 batch_size, seed):
    """
    The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
    asks for the value of one of them (which is outside the sliding window).
    If we tell it upfront which one we are going to be looking for, then
    it answers correctly (mostly).

    Additionally, we compare the results of the v1 and v2 block managers.
    """
    sampling_params = SamplingParams(
        max_tokens=1024,
        ignore_eos=True,
        temperature=0.0,
    )

    prompts, answer, indices = prep_prompts(batch_size)

    print('Getting token ids from block manager v1')
    baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
                                                 prompts,
                                                 sampling_params,
                                                 llm_cb=check_window(prompts))

    check_answers(indices, answer, baseline_texts)

    print('Getting token ids from block manager v2')
    test_texts = get_text_from_llm_generator(test_llm_generator, prompts,
                                             sampling_params)
    check_answers(indices, answer, test_texts)

    cmp = [
        expected_text == actual_text
        for expected_text, actual_text in zip(baseline_texts, test_texts)
    ]
    print(cmp)
    # make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768
    # however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290
    # states that xformers and flash_attn have different ideas about the window
    # size anyways
    assert sum(cmp) > 0.7 * len(cmp)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": MODEL,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,
        "block_size": BLOCK_SIZE,
        "num_gpu_blocks_override": 100000 // BLOCK_SIZE,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "use_v2_block_manager": True,
    "enable_chunked_prefill": True
}])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
    """
    This is similar to test_sliding_window_retrival, however, it doesn't
    compare against the v1 block manager since v1 doesn't support
    chunked prefill with sliding window.

    The results with and without chunked prefill are not the same due to
    numerical instabilities.
    """
    sampling_params = SamplingParams(
        max_tokens=10,
        ignore_eos=True,
        temperature=0.0,
    )

    prompts, answer, indices = prep_prompts(batch_size)

    # We don't compare with the baseline model here, since the results
    # are slightly different due to different tailing in attention.
    test_texts = get_text_from_llm_generator(test_llm_generator,
                                             prompts,
                                             sampling_params,
                                             llm_cb=check_window(prompts))
    check_answers(indices, answer, test_texts)


def prep_prompts(batch_size: int):
    """
    Generate prompts which consist of a bunch of assignments,
    then ask for the value of one of them.
    The prompt is just under 10k tokens; the sliding window is 4k,
    so the answer is outside the sliding window, but should still be correct.
    """
    prompts: List[str] = []
    answer: List[int] = []
    indices: List[int] = []
    random.seed(1)
    for _ in range(batch_size):
        idx = random.randint(30, 90)
        indices.append(idx)
        prompt = "```python\n# We set a number of variables, " + \
                 f"x{idx} will be important later\n"
        ln = random.randint(800, 1100)
        for k in range(30, ln):
            v = random.randint(10, 99)
            if k == idx:
                answer.append(v)
            prompt += f"x{k} = {v}\n"
        prompt += f"# Now, we check the value of x{idx}:\n"
        prompt += f"assert x{idx} == "
        prompts.append(prompt)
    return prompts, answer, indices


def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
    answer2 = [int(text[0:2].strip()) for text in outputs]
    print(list(zip(indices, zip(answer, answer2))))
    numok = 0
    for a1, a2 in zip(answer, answer2):
        if a1 == a2:
            numok += 1
    frac_ok = numok / len(answer)
    print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
    assert frac_ok > 0.7


def check_window(prompts: List[str]):

    def inner(llm: LLM):
        sliding_window = llm.llm_engine.model_config.get_sliding_window()
        assert sliding_window and sliding_window > 0
        assert any(
            len(llm.get_tokenizer().tokenize(prompt)) > sliding_window
            for prompt in prompts)

    return inner
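
As a rough sanity check on the docstring's claim that these prompts overflow the 4k window, the arithmetic below is an illustrative estimate only; the per-line token count is an assumption, and the real check is done with the actual tokenizer by check_window() above.

# Back-of-the-envelope estimate, not vLLM code. Assumes ~9 tokens per
# generated line like "x42 = 17\n"; the exact count depends on the tokenizer.
TOKENS_PER_LINE = 9
min_lines, max_lines = 800, 1100           # range used by prep_prompts()
print(min_lines * TOKENS_PER_LINE,         # ~7200 tokens
      max_lines * TOKENS_PER_LINE)         # ~9900 tokens, vs. a 4096-token window,
# so the assignment to x{idx} near the top of the prompt has fallen outside
# the sliding window by the time the model is asked for it.
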
@@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
        range(prompt_len + num_slots_to_append + num_lookahead_slots)),
        block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
    assert num_consumed_blocks == expected_consumed_blocks


@pytest.mark.parametrize("block_size", [8, 16])
@pytest.mark.parametrize("prompt_len", [10, 300, 1000])
@pytest.mark.parametrize("num_slots_to_append", [50])
@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512])
def test_sliding_window(block_size, prompt_len, num_slots_to_append,
                        sliding_window):
    """Verify append_slots consumes the correct number of blocks from the
    block table.
    """

    num_gpu_blocks = 1024
    watermark = 0.1
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        watermark=watermark,
        sliding_window=sliding_window,
    )

    def check_used(min_n, max_n=None):
        if max_n is None:
            max_n = min_n
        used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
        # print("check", min_n, used, max_n)
        assert min_n <= used
        assert used <= max_n

    def num_blocks(num_tokens):
        return (num_tokens + block_size - 1) // block_size

    check_used(0)

    seq_group = create_seq_group(
        seq_prompt_len=prompt_len,
        seq_output_lens=[0],
    )

    check_used(0)

    # Allocate seq
    assert block_manager.can_allocate(seq_group)
    block_manager.allocate(seq_group)

    check_used(num_blocks(prompt_len))

    # Set seq to RUNNING
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    seq.data.update_num_computed_tokens(prompt_len)
    check_used(num_blocks(prompt_len))

    # this is how we compute it in BlockSpaceManagerV2.__init__
    sliding_blocks = (sliding_window // block_size) + 2
    # plus one block for null block
    sliding_blocks += 1

    # Append tokens to the sequence
    for token_id in range(num_slots_to_append):
        seq.append_token_id(token_id, {token_id: Logprob(0.0)})
        seq.data.update_num_computed_tokens(1)
        block_manager.append_slots(seq, num_lookahead_slots=0)
        if prompt_len < sliding_window + 10:
            check_used(0, sliding_blocks + 1)
        else:
            check_used(sliding_blocks, sliding_blocks + 1)
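
To make the bound concrete for one parametrization above (block_size=16, sliding_window=512), the expected block count works out as follows; illustrative arithmetic only.

block_size, sliding_window = 16, 512
sliding_blocks = sliding_window // block_size + 2   # 34, mirroring BlockSpaceManagerV2.__init__
sliding_blocks += 1                                 # 35, one extra slot for the null block
# Once the sequence is longer than the window, the manager should hold the
# working set within [sliding_blocks, sliding_blocks + 1], i.e. 35 to 36 blocks here.
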
@@ -697,6 +697,10 @@ if triton.__version__ >= "2.1.0":

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

        # 0 means "disable"
        if sliding_window is None or sliding_window <= 0:
            sliding_window = 0

        num_warps = 8 if Lk <= 64 else 8
        if alibi_slopes is not None:
            _fwd_kernel_alibi[grid](
@@ -794,7 +798,7 @@ if triton.__version__ >= "2.1.0":
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
            SLIDING_WINDOW=sliding_window,
            num_warps=num_warps,
            num_stages=1,
        )
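
The kernel convention here is that a window of 0 means "no sliding window", so the wrapper normalizes None or non-positive values before launch. A standalone sketch of that convention (illustration only, not the kernel code):

def normalize_sliding_window(sliding_window):
    # 0 tells the triton kernel to disable sliding-window masking.
    if sliding_window is None or sliding_window <= 0:
        return 0
    return sliding_window

assert normalize_sliding_window(None) == 0
assert normalize_sliding_window(-1) == 0
assert normalize_sliding_window(4096) == 4096
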
@@ -20,6 +20,10 @@ class BlockTable:
        _blocks (Optional[List[Block]], optional): An optional list of existing
            blocks to initialize the BlockTable with. If not provided, an empty
            BlockTable is created.
        max_block_sliding_window (Optional[int], optional): The number of
            blocks to keep around for each sequence. If None, all blocks
            are kept (e.g., when sliding window is not used).
            It should at least fit the sliding window size of the model.

    Attributes:
        _block_size (int): The maximum number of tokens that can be stored in a
@@ -37,6 +41,7 @@ class BlockTable:
        block_size: int,
        block_allocator: DeviceAwareBlockAllocator,
        _blocks: Optional[List[Block]] = None,
        max_block_sliding_window: Optional[int] = None,
    ):
        self._block_size = block_size
        self._allocator = block_allocator
@@ -44,6 +49,7 @@ class BlockTable:
            _blocks = []
        self._blocks: List[Block] = _blocks

        self._max_block_sliding_window = max_block_sliding_window
        # Use helper method instead of directly calculating, as blocks
        # may not be allocated.
        self._num_full_slots = len(self._get_all_token_ids())
@@ -89,7 +95,8 @@ class BlockTable:

    def append_token_ids(self,
                         token_ids: List[int],
                         num_lookahead_slots: int = 0) -> None:
                         num_lookahead_slots: int = 0,
                         num_computed_slots: Optional[int] = None) -> None:
        """Appends a sequence of token IDs to the existing blocks in the
        BlockTable.

@@ -104,13 +111,35 @@ class BlockTable:

        Args:
            token_ids (List[int]): The sequence of token IDs to be appended.
            num_computed_slots (Optional[int]): The number of KV cache slots
                that are already filled (computed).
                When sliding window is enabled, this is used to compute how many
                blocks to drop at the front of the sequence.
                Without sliding window, None can be passed.
                Without chunked prefill, it should be the same as
                _num_full_slots.
        """
        assert self._is_allocated
        assert self._is_allocated, "no blocks have been allocated"
        assert len(self._blocks) > 0

        # Drop blocks that are no longer needed due to sliding window
        if self._max_block_sliding_window is not None:
            null_block = self._allocator.allocate_or_get_null_block()
            assert num_computed_slots is not None
            end_block_idx = (num_computed_slots //
                             self._block_size) - self._max_block_sliding_window
            for idx in range(0, end_block_idx):
                b = self._blocks[idx]
                if b is not null_block:
                    self._allocator.free(b)
                self._blocks[idx] = null_block

        # Ensure there are enough empty slots for the new tokens plus
        # lookahead slots
        self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                    num_lookahead_slots)

        # Update the blocks with the new tokens
        blocks = self._blocks[self._num_full_slots // self._block_size:]
        token_blocks = self._chunk_token_blocks_for_append(token_ids)
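
The arithmetic that decides which leading blocks get replaced by the null block can be traced with concrete numbers; the values below are illustrative only (4096-token window and block size 16 as in the tests, 5000 computed tokens).

# Not vLLM code: traces the end_block_idx computation above.
block_size = 16
max_block_sliding_window = 4096 // block_size + 1 + 1   # 258, from BlockSpaceManagerV2.__init__
num_computed_slots = 5000                               # KV cache slots already filled

end_block_idx = num_computed_slots // block_size - max_block_sliding_window
print(end_block_idx)   # 312 - 258 = 54: blocks 0..53 are freed and pointed at
                       # the null block; the remaining blocks still back the window.
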
@@ -168,6 +197,7 @@ class BlockTable:
            block_size=self._block_size,
            block_allocator=self._allocator,
            _blocks=forked_blocks,
            max_block_sliding_window=self._max_block_sliding_window,
        )

    def free(self) -> None:
@@ -105,11 +105,19 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
            Device.GPU: gpu_block_allocator,
        }

        self._null_block: Optional[Block] = None

        self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
        for _, allocator in self._allocators.items():
            for block_id in allocator.all_block_ids:
                self._block_ids_to_allocator[block_id] = allocator

    def allocate_or_get_null_block(self) -> Block:
        if self._null_block is None:
            self._null_block = NullBlock(
                self.allocate_mutable(None, Device.GPU))
        return self._null_block

    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocates a new mutable block on the specified device.
@@ -149,6 +157,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
        Args:
            block (Block): The block to be freed.
        """
        # Null block should never be freed
        if isinstance(block, NullBlock):
            return
        block_id = block.block_id
        assert block_id is not None
        allocator = self._block_ids_to_allocator[block_id]
@@ -165,6 +176,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
            List[Block]: A new list of blocks that shares the same memory as the
            original sequence.
        """
        # do not attempt to fork the null block
        assert not isinstance(last_block, NullBlock)
        block_id = last_block.block_id
        assert block_id is not None
        allocator = self._block_ids_to_allocator[block_id]
@@ -226,3 +239,64 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        raise NotImplementedError


class NullBlock(Block):
    """
    Null blocks are used as placeholders for KV cache blocks that have
    been dropped due to sliding window.
    This implementation just wraps an ordinary block and prevents it from
    being modified. It also allows for testing if a block is NullBlock
    via isinstance().
    """

    def __init__(self, proxy: Block):
        super().__init__()
        self._proxy = proxy

    def append_token_ids(self, token_ids: List[BlockId]):
        raise ValueError("null block should not be modified")

    @property
    def block_id(self):
        return self._proxy.block_id

    @block_id.setter
    def block_id(self, value: Optional[BlockId]):
        raise ValueError("null block should not be modified")

    @property
    def token_ids(self) -> List[BlockId]:
        return self._proxy.token_ids

    @property
    def num_empty_slots(self) -> BlockId:
        return self._proxy.num_empty_slots

    @property
    def is_full(self):
        return self._proxy.is_full

    @property
    def prev_block(self):
        return self._proxy.prev_block

    @property
    def computed(self):
        return self._proxy.computed

    @computed.setter
    def computed(self, value):
        self._proxy.computed = value

    @property
    def last_accessed(self) -> float:
        return self._proxy.last_accessed

    @last_accessed.setter
    def last_accessed(self, last_accessed_ts: float):
        self._proxy.last_accessed = last_accessed_ts

    @property
    def content_hash(self):
        return self._proxy.content_hash
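
A short usage sketch of the null-block behavior described above; this is a hypothetical snippet that assumes an already-configured CpuGpuBlockAllocator instance named allocator.

null_block = allocator.allocate_or_get_null_block()
assert allocator.allocate_or_get_null_block() is null_block   # at most one per allocator

allocator.free(null_block)                # no-op: the null block is never freed
try:
    null_block.append_token_ids([1, 2, 3])
except ValueError:
    pass                                  # writes to a dropped (null) block are rejected
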
@@ -203,3 +203,12 @@ class DeviceAwareBlockAllocator(ABC):
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        pass

    @abstractmethod
    def allocate_or_get_null_block(self) -> Block:
        """
        Null blocks are used as placeholders for KV cache blocks that have
        been dropped due to sliding window.
        There is at most one null block per allocator.
        """
        pass
@@ -66,9 +66,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        assert sliding_window is None, "Sliding window not yet supported"

        self.block_sliding_window = None
        self.sliding_window = sliding_window
        # max_block_sliding_window is the max number of blocks that need to be
        # allocated
        self.max_block_sliding_window = None
        if sliding_window is not None:
            # +1 here because // rounds down
            num_blocks = sliding_window // block_size + 1
            # +1 here because the last block may not be full,
            # and so the sequence stretches one more block at the beginning
            # For example, if sliding_window is 3 and block_size is 4,
            # we may need 2 blocks when the second block only holds 1 token.
            self.max_block_sliding_window = num_blocks + 1

        self.watermark = watermark
        assert watermark >= 0.0
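
With the 4k window and 16-token blocks used in the tests above, that computation gives the following; illustrative numbers only.

sliding_window, block_size = 4096, 16
num_blocks = sliding_window // block_size + 1      # 257 (+1 because // rounds down)
max_block_sliding_window = num_blocks + 1          # 258 (+1 because the last block may be partial)
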
@@ -96,10 +105,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
            block_size=self.block_size,
        )

        assert self.block_sliding_window is None
        if self.block_sliding_window is not None:
        if self.max_block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
                                      self.max_block_sliding_window)

        num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
            device=Device.GPU)
@@ -125,8 +133,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
        block_table = BlockTable(
            block_size=self.block_size,
            block_allocator=self.block_allocator,
            max_block_sliding_window=self.max_block_sliding_window,
        )
        assert self.block_sliding_window is None

        block_table.allocate(seq.get_token_ids())
        self.block_tables[seq.seq_id] = block_table
@@ -174,6 +183,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
        block_table.append_token_ids(
            token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
            num_lookahead_slots=num_lookahead_slots,
            num_computed_slots=seq.data.get_num_computed_tokens(),
        )

        # Return any new copy-on-writes.
@@ -648,7 +648,8 @@ class EngineArgs:
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled):
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")
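
With this check relaxed, a sliding-window model can be combined with chunked prefill as long as the v2 block manager is enabled. A minimal sketch using the same keyword arguments the new test passes (one possible configuration, not the only one):

from vllm import LLM, SamplingParams

llm = LLM(model="bigcode/starcoder2-3b",      # the 4k sliding-window model from the tests
          enforce_eager=True,
          use_v2_block_manager=True,
          enable_chunked_prefill=True)
outputs = llm.generate(["# a long prompt would go here\nassert x42 == "],
                       SamplingParams(max_tokens=10, temperature=0.0))
print(outputs[0].outputs[0].text)
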
@@ -68,8 +68,11 @@ class CacheEngine:
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
            kv_cache.append(
                torch.empty(kv_cache_shape,
                torch.zeros(kv_cache_shape,
                            dtype=self.dtype,
                            pin_memory=pin_memory,
                            device=device))
@@ -269,6 +269,12 @@ class ModelRunner:
        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        if self.sliding_window is not None:
            sliding_window_blocks = (self.sliding_window + self.block_size -
                                     1) // self.block_size
            block_aligned_sliding_window = \
                sliding_window_blocks * self.block_size

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt
@@ -309,6 +315,30 @@ class ModelRunner:
                                    and self.sliding_window is None
                                    and is_prompt)

                # These are seq_len/context_len capped to the sliding window.
                # They are passed to decode kernel.
                # We still need original seq_len/context_len to compute slot
                # mapping (and input position) below.
                curr_sliding_window_blocks = None
                sliding_seq_len = seq_len
                sliding_context_len = context_len

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    curr_sliding_window_blocks = sliding_window_blocks
                    if self.scheduler_config.use_v2_block_manager:
                        # number of elements in last block
                        suff_len = seq_len % self.block_size
                        sliding_seq_len = min(
                            seq_len, block_aligned_sliding_window + suff_len)
                        if suff_len > 0:
                            curr_sliding_window_blocks += 1
                    else:
                        sliding_seq_len = min(seq_len, self.sliding_window)
                    sliding_context_len = sliding_seq_len - 1

                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
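
Tracing the v2 branch with concrete numbers makes the capping easier to follow; the values below are illustrative only (4096-token window, block size 16, a 5000-token sequence during decode).

# Not vLLM code: mirrors the v2 block manager branch above.
sliding_window, block_size, seq_len = 4096, 16, 5000
sliding_window_blocks = (sliding_window + block_size - 1) // block_size      # 256
block_aligned_sliding_window = sliding_window_blocks * block_size            # 4096

suff_len = seq_len % block_size                                              # 8 tokens in the last block
sliding_seq_len = min(seq_len, block_aligned_sliding_window + suff_len)      # 4104
curr_sliding_window_blocks = sliding_window_blocks + (1 if suff_len else 0)  # 257
sliding_context_len = sliding_seq_len - 1                                    # 4103
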
@@ -316,6 +346,13 @@ class ModelRunner:
                    assert computed_block_nums is not None
                    context_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[context_len:]

                    # need to think what to set it to when we have both sliding
                    # window and prefix caching...
                    assert self.sliding_window is None, \
                        "Prefix caching is not supported with sliding window"
                    sliding_context_len = context_len

                    if self.attn_backend.get_name() == "flash-attn":
                        # NOTE(woosuk): For flash-attn, the block table should
                        # include the entries for the incoming prefill tokens.
@@ -329,14 +366,9 @@ class ModelRunner:
                if seq_group_metadata.block_tables is not None:
                    # chunked prefill or decode
                    block_table = seq_group_metadata.block_tables[seq_id]
                    if self.sliding_window is not None:
                        # chunked prefill doesn't support sliding window.
                        assert (not self.scheduler_config.
                                chunked_prefill_enabled)
                        sliding_window_blocks = (self.sliding_window //
                                                 self.block_size)
                        block_table = block_table[-sliding_window_blocks:]

                    if curr_sliding_window_blocks is not None:
                        block_table = block_table[
                            -curr_sliding_window_blocks:]
                    if self.attn_backend.get_name() == "flashinfer":
                        paged_kv_indices.extend(block_table)
                        paged_kv_indptr.append(paged_kv_indptr[-1] +
@@ -354,16 +386,9 @@ class ModelRunner:
                    block_table = []
                block_tables.append(block_table)

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    seq_len = min(seq_len, self.sliding_window)
                    context_len = seq_len - 1

                seq_lens.append(seq_len)
                context_lens.append(context_len)
                query_len = seq_len - context_len
                seq_lens.append(sliding_seq_len)
                context_lens.append(sliding_context_len)
                query_len = sliding_seq_len - sliding_context_len
                query_lens.append(query_len)
                input_tokens.extend(tokens)
                input_positions.extend(list(range(context_len, seq_len)))
@@ -380,16 +405,15 @@ class ModelRunner:
                        "seq_len: {}, context_len: {}, query_len: {}".format(
                            seq_len, context_len, query_len))
                    num_decode_tokens += query_len
                    decode_seq_lens.append(seq_len)
                    decode_seq_lens.append(sliding_seq_len)

                if lora_id > 0:
                    lora_requests.add(seq_group_metadata.lora_request)

                lora_index_mapping += [lora_id] * (seq_len - context_len)
                lora_index_mapping += [lora_id] * query_len
                lora_prompt_mapping.extend(
                    [lora_id] *
                    (seq_len -
                     context_len if seq_group_metadata.sampling_params
                    (query_len if seq_group_metadata.sampling_params
                     and seq_group_metadata.sampling_params.prompt_logprobs
                     else 1))
@@ -417,9 +441,10 @@ class ModelRunner:
                start_idx = 0
                if self.sliding_window is not None:
                    if is_prompt:
                        assert context_len == 0, (
                        assert self.scheduler_config.use_v2_block_manager \
                            or context_len == 0, (
                            "Prefix caching is currently not supported with "
                            "sliding window attention")
                            "sliding window attention in V1 block manager")
                        # It is an optimization. When it is decoding, it is always
                        # 0. When prefill, we use it to not write slots to kv cache
                        # to save memory.