Rename variables and methods (#91)

Woosuk Kwon 2023-05-10 00:58:31 -07:00 committed by GitHub
parent ce26e57fd3
commit 8d66a7b6d7
7 changed files with 64 additions and 83 deletions

View File

@@ -27,7 +27,7 @@ class LogicalTokenBlock:
     def is_full(self) -> bool:
         return self.num_tokens == self.block_size
 
-    def append(self, token_ids: List[int]) -> None:
+    def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
         self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)
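
Usage sketch (illustrative only, not part of this commit): exercising the renamed LogicalTokenBlock.append_tokens, using the block_number/block_size constructor arguments that appear in the Sequence hunk further down.

    # Illustrative only: fill a logical block in two calls.
    block = LogicalTokenBlock(block_number=0, block_size=4)
    block.append_tokens([10, 11])            # 2 of 4 slots used
    assert block.get_num_empty_slots() == 2
    block.append_tokens([12, 13])            # now full
    assert block.is_full()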

View File

@@ -97,15 +97,15 @@ class BlockSpaceManager:
         for seq in seq_group.seqs:
             self.block_tables[seq.seq_id] = block_table.copy()
 
-    def can_append(self, seq_group: SequenceGroup) -> bool:
+    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
         # Simple heuristic: If there is at least one free block
         # for each sequence, we can append.
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
         return num_seqs <= num_free_gpu_blocks
 
-    def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
-        """Allocate a physical slot for the new token."""
+    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
+        """Allocate a physical slot for a new token."""
         logical_blocks = seq.logical_token_blocks
         block_table = self.block_tables[seq.seq_id]
@@ -156,7 +156,7 @@ class BlockSpaceManager:
         num_free_blocks = self.gpu_allocator.get_num_free_blocks()
         # NOTE: Conservatively, we assume that every sequence will allocate
         # at least one free block right after the swap-in.
-        # NOTE: This should match the logic in can_append().
+        # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
         return num_free_blocks - num_required_blocks >= self.watermark_blocks
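
Usage sketch (illustrative only): the calling pattern these renames serve, mirroring Scheduler._append_slot later in this commit. block_manager, seq_group, and blocks_to_copy are assumed to already exist in the caller's scope.

    from cacheflow.sequence import SequenceStatus

    if block_manager.can_append_slot(seq_group):
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = block_manager.append_slot(seq)  # Optional[Tuple[src_block, dst_block]]
            if ret is not None:
                # A block had to be copied; record it, equivalent to the
                # if/else bookkeeping in the scheduler hunk below.
                src_block, dst_block = ret
                blocks_to_copy.setdefault(src_block, []).append(dst_block)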

View File

@@ -9,7 +9,7 @@ from cacheflow.core.policy import PolicyFactory
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus
@@ -105,7 +105,7 @@ class Scheduler:
         preempted: List[SequenceGroup] = []
         while self.running:
             seq_group = self.running.pop(0)
-            while not self.block_manager.can_append(seq_group):
+            while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
                     victim_seq_group = self.running.pop(-1)
@@ -119,7 +119,7 @@ class Scheduler:
                     break
             else:
                 # Append new slots to the sequence group.
-                self._append(seq_group, blocks_to_copy)
+                self._append_slot(seq_group, blocks_to_copy)
                 running.append(seq_group)
         self.running = running
@@ -143,7 +143,7 @@ class Scheduler:
             seq_group = self.swapped.pop(0)
             self._swap_in(seq_group, blocks_to_swap_in)
-            self._append(seq_group, blocks_to_copy)
+            self._append_slot(seq_group, blocks_to_copy)
             self.running.append(seq_group)
 
         num_batched_tokens = sum(
@@ -252,7 +252,7 @@ class Scheduler:
         prompt_group_ids = scheduler_output[3]
 
         # Create input data structures.
-        input_seq_groups: List[SequenceGroupInputs] = []
+        seq_group_metadata_list: List[SequenceGroupMetadata] = []
         updated_seq_groups: List[SequenceGroup] = self.running.copy()
 
         for seq_group in self.running:
@@ -274,7 +274,7 @@ class Scheduler:
                 # sequence length
                 seq_len = seq.get_len()
 
-            input_seq_group = SequenceGroupInputs(
+            seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
                 input_tokens=input_tokens,
@@ -283,14 +283,14 @@ class Scheduler:
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
-            input_seq_groups.append(input_seq_group)
+            seq_group_metadata_list.append(seq_group_metadata)
 
         # Execute the first stage of the pipeline.
-        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
+        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
             # Swap in and swap out should never happen at the same time.
             assert not (blocks_to_swap_in and blocks_to_swap_out)
             self.controllers[0].execute_stage(
-                input_seq_groups,
+                seq_group_metadata_list,
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,
@@ -330,7 +330,7 @@ class Scheduler:
             # Append a new token to the sequence.
             output = seq_outputs[seq.seq_id]
-            seq.append(output.output_token, output.logprobs)
+            seq.append_token(output.output_token, output.logprobs)
 
             # Check if the sequence has generated a stop token.
             if output.output_token in stop_token_ids:
@@ -360,13 +360,13 @@ class Scheduler:
         if seq_group.group_id not in self.num_steps:
             self.num_steps[seq_group.group_id] = 0
 
-    def _append(
+    def _append_slot(
         self,
         seq_group: SequenceGroup,
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            ret = self.block_manager.append(seq)
+            ret = self.block_manager.append_slot(seq)
             if ret is not None:
                 src_block, dst_block = ret
                 if src_block in blocks_to_copy:
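
Usage sketch (illustrative only): constructing the renamed SequenceGroupMetadata roughly the way the scheduler does above. The keyword list is truncated in this hunk, so context_len and seq_logprobs are inferred from the Worker hunks below and should be treated as assumptions.

    # Hypothetical values; field names as visible elsewhere in this commit.
    seq_group_metadata_list = []
    group_id, seq_id = 0, 0
    prompt_token_ids = [1, 5, 7]
    seq_group_metadata = SequenceGroupMetadata(
        group_id=group_id,
        is_prompt=True,
        input_tokens={seq_id: prompt_token_ids},   # per-sequence token ids
        context_len=len(prompt_token_ids),         # inferred field (see Worker hunks)
        seq_logprobs={seq_id: 0.0},                # inferred field (see Worker hunks)
        sampling_params=sampling_params,           # a SamplingParams instance, assumed in scope
        block_tables={seq_id: [0]},                # logical block index -> physical block
    )
    seq_group_metadata_list.append(seq_group_metadata)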

View File

@@ -1,4 +1,4 @@
-from typing import Optional, Set, Dict
+from typing import Dict, Set
 
 
 class SamplingParams:
@@ -12,7 +12,6 @@ class SamplingParams:
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
-        context_window_size: Optional[int],
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')
@@ -27,10 +26,6 @@ class SamplingParams:
         if num_logprobs < 0:
             raise ValueError(
                 f'num_logprobs must be non-negative, got {num_logprobs}.')
-        if context_window_size is not None and context_window_size < 0:
-            raise ValueError(
-                'context_window_size must be non-negative, '
-                f'got {context_window_size}.')
 
         if use_beam_search:
             if n == 1:
@@ -58,7 +53,6 @@ class SamplingParams:
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
-        self.context_window_size = context_window_size
 
     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '
@@ -67,8 +61,7 @@ class SamplingParams:
                 f'use_beam_search={self.use_beam_search}, '
                 f'stop_token_ids={self.stop_token_ids}, '
                 f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}, '
-                f'context_window_size={self.context_window_size})')
+                f'num_logprobs={self.num_logprobs})')
 
     @classmethod
     def from_dict(cls, d: Dict) -> 'SamplingParams':
@@ -80,5 +73,4 @@ class SamplingParams:
             stop_token_ids=set(d.get('stop_token_ids', set())),
             max_num_steps=d.get('max_num_steps', 16),
            num_logprobs=d.get('num_logprobs', 0),
-            context_window_size=d.get('context_window_size', None),
        )
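
Usage sketch (illustrative only): after this change SamplingParams no longer accepts context_window_size. The sketch assumes the keys not shown in this hunk (n, temperature, and so on) also have from_dict defaults, which the visible d.get(...) pattern suggests.

    # Illustrative only; omitted keys fall back to the defaults in from_dict.
    params = SamplingParams.from_dict({
        'stop_token_ids': [2],
        'max_num_steps': 32,
        'num_logprobs': 0,
    })
    print(params)   # __repr__ no longer reports context_window_size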

View File

@@ -18,45 +18,46 @@ class Sequence:
     def __init__(
         self,
         seq_id: int,
-        token_ids: List[int],
+        prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
         self.block_size = block_size
+        self.prompt_len = len(prompt_token_ids)
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the given token ids.
-        self.add(token_ids)
+        # Initialize the logical token blocks with the prompt token ids.
+        self._append_tokens(prompt_token_ids)
 
-        self.prompt_len = len(token_ids)
         self.status = SequenceStatus.WAITING
 
         # Used for beam search.
         self.output_logprobs: List[Dict[int, float]] = []
         self.cumulative_logprobs = 0.0
 
-    def add_block(self) -> None:
+    def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
             block_size=self.block_size,
         )
         self.logical_token_blocks.append(block)
 
-    def add(self, token_ids: List[int]) -> None:
+    def _append_tokens(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
-                self.add_block()
+                self._append_logical_block()
 
             last_block = self.logical_token_blocks[-1]
             if last_block.is_full():
-                self.add_block()
+                self._append_logical_block()
                 last_block = self.logical_token_blocks[-1]
 
             num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append(token_ids[:num_empty_slots])
+            last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]
 
-    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self.add([token_id])
+        self._append_tokens([token_id])
         self.output_logprobs.append(logprobs)
         self.cumulative_logprobs += logprobs[token_id]
@@ -121,7 +122,7 @@ class SequenceGroup:
                 f'num_seqs={len(self.seqs)})')
 
 
-class SequenceGroupInputs:
+class SequenceGroupMetadata:
 
     def __init__(
         self,
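
Usage sketch (illustrative only) of the renamed Sequence API: a six-token prompt with block_size=4 fills one logical block and spills into a second, and generated tokens go through append_token, which also accumulates logprobs.

    seq = Sequence(seq_id=0, prompt_token_ids=[11, 12, 13, 14, 15, 16], block_size=4)
    assert len(seq.logical_token_blocks) == 2   # 4 tokens in block 0, 2 in block 1
    assert seq.prompt_len == 6
    seq.append_token(17, logprobs={17: -0.25})  # one generated token
    assert seq.cumulative_logprobs == -0.25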

View File

@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Tuple, Optional
+from typing import List, Optional, Tuple, Union
 
 try:
     import ray
@@ -6,7 +6,6 @@ except ImportError:
     ray = None
 
 from cacheflow.core.scheduler import Scheduler
-from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker
@@ -81,23 +80,12 @@ class Controller:
         self.next_node = next_node
         self.is_last_stage = isinstance(next_node, Scheduler)
 
-    def execute_stage(
-        self,
-        input_seq_groups: List[SequenceGroupInputs],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> None:
+    def execute_stage(self, *args, **kwargs) -> None:
         all_outputs = []
         for worker in self.workers:
             executor = (worker.execute_stage.remote
                         if self.use_ray else worker.execute_stage)
-            output = executor(
-                input_seq_groups,
-                blocks_to_swap_in,
-                blocks_to_swap_out,
-                blocks_to_copy,
-            )
+            output = executor(*args, **kwargs)
             all_outputs.append(output)
 
         if self.use_ray:
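
Usage sketch (illustrative only): Controller.execute_stage now forwards whatever it receives, so the scheduler call shown earlier has to line up with Worker.execute_stage's signature (next hunks). The forwarding removes the duplicated parameter list at the cost of any static checking of the stage arguments.

    # `controller` and `seq_group_metadata_list` are assumed to exist in the caller's scope.
    controller.execute_stage(
        seq_group_metadata_list,
        blocks_to_swap_in={},
        blocks_to_swap_out={},
        blocks_to_copy={},
    )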

View File

@@ -8,10 +8,11 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 
 
 class Worker:
 
     def __init__(
@@ -93,30 +94,29 @@ class Worker:
     def prepare_inputs(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         seq_logprobs: Dict[int, float] = {}
         sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
-        for input_seq_group in input_seq_groups:
-            if not input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if not seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             # Use any sequence in the group.
             seq_id = seq_ids[0]
 
-            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
@@ -126,7 +126,7 @@ class Worker:
             input_positions.extend(range(len(prompt_tokens)))
 
             # Compute the slot mapping.
-            block_table = input_seq_group.block_tables[seq_id]
+            block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
@@ -138,31 +138,31 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        for input_seq_group in input_seq_groups:
-            if input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             for seq_id in seq_ids:
-                assert len(input_seq_group.input_tokens[seq_id]) == 1
-                generation_token = input_seq_group.input_tokens[seq_id][0]
+                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
+                generation_token = seq_group_metadata.input_tokens[seq_id][0]
                 input_tokens.append(generation_token)
 
-                position = input_seq_group.context_len - 1
+                position = seq_group_metadata.context_len - 1
                 input_positions.append(position)
 
-                block_table = input_seq_group.block_tables[seq_id]
+                block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)
 
                 max_context_len = max(
-                    max_context_len, input_seq_group.context_len)
+                    max_context_len, seq_group_metadata.context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(input_seq_group.context_len)
+                context_lens.append(seq_group_metadata.context_len)
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
@@ -203,30 +203,30 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
-        command_issued = False
+        issued_cache_op = False
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            command_issued = True
+            issued_cache_op = True
 
-        if command_issued:
+        if issued_cache_op:
             cache_events = self.cache_events
         else:
             cache_events = None
 
         # If there is no input, we don't need to execute the model.
-        if not input_seq_groups:
+        if not seq_group_metadata_list:
             if cache_events is not None:
                 for event in cache_events:
                     event.wait()
@@ -234,7 +234,7 @@ class Worker:
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            input_seq_groups)
+            seq_group_metadata_list)
 
         # Execute the model.
         output = self.model(
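
For reference (illustrative only): how a token position maps to a KV-cache slot in prepare_inputs above. The final slot computation is cut off by the hunk, so the formula below is inferred from the block_number/block_offset lines that are shown.

    block_size = 8
    block_table = [3, 7]              # logical block index -> physical block number
    position = 11                     # the 12th token of the sequence
    block_number = block_table[position // block_size]    # -> 7
    block_offset = position % block_size                   # -> 3
    slot = block_number * block_size + block_offset        # -> 59 (assumed formula)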