Rename variables and methods (#91)

Woosuk Kwon, 2023-05-10 00:58:31 -07:00, committed by GitHub
commit 8d66a7b6d7 (parent ce26e57fd3)
7 changed files with 64 additions and 83 deletions

View File

@@ -27,7 +27,7 @@ class LogicalTokenBlock:
     def is_full(self) -> bool:
         return self.num_tokens == self.block_size
 
-    def append(self, token_ids: List[int]) -> None:
+    def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
         self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)
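
Note: append_tokens assigns into a preallocated, fixed-size token_ids list rather than growing it. A minimal runnable sketch of the class around the renamed method; the constructor and the placeholder id are assumptions, not shown in this diff:

    from typing import List

    _BLANK_TOKEN_ID = -1  # hypothetical placeholder for empty slots

    class LogicalTokenBlock:

        def __init__(self, block_number: int, block_size: int) -> None:
            self.block_number = block_number
            self.block_size = block_size
            # Preallocate so append_tokens can assign into a slice.
            self.token_ids = [_BLANK_TOKEN_ID] * block_size
            self.num_tokens = 0

        def get_num_empty_slots(self) -> int:
            return self.block_size - self.num_tokens

        def is_full(self) -> bool:
            return self.num_tokens == self.block_size

        def append_tokens(self, token_ids: List[int]) -> None:
            assert len(token_ids) <= self.get_num_empty_slots()
            self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
            self.num_tokens += len(token_ids)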

View File

@@ -97,15 +97,15 @@ class BlockSpaceManager:
         for seq in seq_group.seqs:
             self.block_tables[seq.seq_id] = block_table.copy()
 
-    def can_append(self, seq_group: SequenceGroup) -> bool:
+    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
         # Simple heuristic: If there is at least one free block
         # for each sequence, we can append.
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
         return num_seqs <= num_free_gpu_blocks
 
-    def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
-        """Allocate a physical slot for the new token."""
+    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
+        """Allocate a physical slot for a new token."""
         logical_blocks = seq.logical_token_blocks
         block_table = self.block_tables[seq.seq_id]
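
The heuristic is sound because a RUNNING sequence generates one token per scheduler step, so each sequence can need at most one new physical block per step. A worked check with assumed numbers (not from this diff):

    # 3 RUNNING sequences but only 2 free GPU blocks: can_append_slot is
    # False, so the scheduler must preempt something first.
    num_free_gpu_blocks = 2
    num_seqs = 3
    assert not (num_seqs <= num_free_gpu_blocks)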
@@ -156,7 +156,7 @@ class BlockSpaceManager:
         num_free_blocks = self.gpu_allocator.get_num_free_blocks()
         # NOTE: Conservatively, we assume that every sequence will allocate
         # at least one free block right after the swap-in.
-        # NOTE: This should match the logic in can_append().
+        # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
         return num_free_blocks - num_required_blocks >= self.watermark_blocks
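
The swap-in check reserves one extra block per swapped sequence on top of the blocks being brought back, and admits the group only if the free pool stays at or above the watermark. With assumed numbers:

    num_free_blocks = 40
    num_required_blocks = 30 + 2   # len(blocks) + num_swapped_seqs
    watermark_blocks = 8
    # 40 - 32 = 8 >= 8: admitted; one fewer free block and it is rejected.
    assert num_free_blocks - num_required_blocks >= watermark_blocks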

View File

@ -9,7 +9,7 @@ from cacheflow.core.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs from cacheflow.sequence import SequenceGroupMetadata
from cacheflow.sequence import SequenceOutputs from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus from cacheflow.sequence import SequenceStatus
@@ -105,7 +105,7 @@ class Scheduler:
         preempted: List[SequenceGroup] = []
         while self.running:
             seq_group = self.running.pop(0)
-            while not self.block_manager.can_append(seq_group):
+            while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
                     victim_seq_group = self.running.pop(-1)
@@ -119,7 +119,7 @@ class Scheduler:
                     break
             else:
                 # Append new slots to the sequence group.
-                self._append(seq_group, blocks_to_copy)
+                self._append_slot(seq_group, blocks_to_copy)
                 running.append(seq_group)
         self.running = running
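
Note that the else: at line +119 binds to the inner while, not to an if: Python runs a while's else block only when the loop exits without break, i.e. only once can_append_slot has succeeded. A condensed sketch of the preempt-or-append pattern (the preempt helper and queue names are taken from context, not from this hunk):

    while not block_manager.can_append_slot(seq_group):
        if running:
            # Preempt the lowest-priority group to free its blocks.
            victim_seq_group = running.pop(-1)
            preempt(victim_seq_group, blocks_to_swap_out)
            preempted.append(victim_seq_group)
        else:
            # No victim available: the group preempts itself and stops.
            preempt(seq_group, blocks_to_swap_out)
            preempted.append(seq_group)
            break
    else:
        # Reached only without a break: there is room to grow.
        _append_slot(seq_group, blocks_to_copy)
        running.append(seq_group)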
@ -143,7 +143,7 @@ class Scheduler:
seq_group = self.swapped.pop(0) seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in) self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy) self._append_slot(seq_group, blocks_to_copy)
self.running.append(seq_group) self.running.append(seq_group)
num_batched_tokens = sum( num_batched_tokens = sum(
@@ -252,7 +252,7 @@ class Scheduler:
         prompt_group_ids = scheduler_output[3]
 
         # Create input data structures.
-        input_seq_groups: List[SequenceGroupInputs] = []
+        seq_group_metadata_list: List[SequenceGroupMetadata] = []
         updated_seq_groups: List[SequenceGroup] = self.running.copy()
 
         for seq_group in self.running:
@@ -274,7 +274,7 @@ class Scheduler:
             # sequence length
             seq_len = seq.get_len()
 
-            input_seq_group = SequenceGroupInputs(
+            seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
                 input_tokens=input_tokens,
@@ -283,14 +283,14 @@ class Scheduler:
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
-            input_seq_groups.append(input_seq_group)
+            seq_group_metadata_list.append(seq_group_metadata)
 
         # Execute the first stage of the pipeline.
-        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
+        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
             # Swap in and swap out should never happen at the same time.
             assert not (blocks_to_swap_in and blocks_to_swap_out)
             self.controllers[0].execute_stage(
-                input_seq_groups,
+                seq_group_metadata_list,
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,
@@ -330,7 +330,7 @@ class Scheduler:
             # Append a new token to the sequence.
             output = seq_outputs[seq.seq_id]
-            seq.append(output.output_token, output.logprobs)
+            seq.append_token(output.output_token, output.logprobs)
 
             # Check if the sequence has generated a stop token.
             if output.output_token in stop_token_ids:
@@ -360,13 +360,13 @@ class Scheduler:
             if seq_group.group_id not in self.num_steps:
                 self.num_steps[seq_group.group_id] = 0
 
-    def _append(
+    def _append_slot(
         self,
         seq_group: SequenceGroup,
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            ret = self.block_manager.append(seq)
+            ret = self.block_manager.append_slot(seq)
             if ret is not None:
                 src_block, dst_block = ret
                 if src_block in blocks_to_copy:
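
When append_slot returns a pair, the sequence's last physical block is shared (typically because forked sequences in beam search reference it) and must be copied before it can be written; _append_slot records the mapping for the cache engine. The if/else that this hunk truncates is equivalent to a setdefault (sketch, assuming the Optional[Tuple[int, int]] return type shown in the block manager above):

    ret = self.block_manager.append_slot(seq)
    if ret is not None:
        src_block, dst_block = ret
        # Equivalent shorthand for the truncated if/else:
        blocks_to_copy.setdefault(src_block, []).append(dst_block)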

View File

@@ -1,4 +1,4 @@
-from typing import Optional, Set, Dict
+from typing import Dict, Set
 
 
 class SamplingParams:
@@ -12,7 +12,6 @@ class SamplingParams:
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
-        context_window_size: Optional[int],
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')
@@ -27,10 +26,6 @@ class SamplingParams:
         if num_logprobs < 0:
             raise ValueError(
                 f'num_logprobs must be non-negative, got {num_logprobs}.')
-        if context_window_size is not None and context_window_size < 0:
-            raise ValueError(
-                'context_window_size must be non-negative, '
-                f'got {context_window_size}.')
 
         if use_beam_search:
             if n == 1:
@@ -58,7 +53,6 @@ class SamplingParams:
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
-        self.context_window_size = context_window_size
 
     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '
@@ -67,8 +61,7 @@ class SamplingParams:
                 f'use_beam_search={self.use_beam_search}, '
                 f'stop_token_ids={self.stop_token_ids}, '
                 f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}, '
-                f'context_window_size={self.context_window_size})')
+                f'num_logprobs={self.num_logprobs}')
 
     @classmethod
     def from_dict(cls, d: Dict) -> 'SamplingParams':
@@ -80,5 +73,4 @@ class SamplingParams:
             stop_token_ids=set(d.get('stop_token_ids', set())),
             max_num_steps=d.get('max_num_steps', 16),
             num_logprobs=d.get('num_logprobs', 0),
-            context_window_size=d.get('context_window_size', None),
         )
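
After this change callers simply stop passing context_window_size. A hypothetical from_dict call using only keys visible in this diff; omitted keys fall back to the defaults shown above:

    params = SamplingParams.from_dict({
        'stop_token_ids': {2},   # e.g. an EOS token id (assumed value)
        'max_num_steps': 32,
        'num_logprobs': 0,
    })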

View File

@@ -18,45 +18,46 @@ class Sequence:
     def __init__(
         self,
         seq_id: int,
-        token_ids: List[int],
+        prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
         self.block_size = block_size
+        self.prompt_len = len(prompt_token_ids)
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the given token ids.
-        self.add(token_ids)
-        self.prompt_len = len(token_ids)
+        # Initialize the logical token blocks with the prompt token ids.
+        self._append_tokens(prompt_token_ids)
 
         self.status = SequenceStatus.WAITING
 
+        # Used for beam search.
         self.output_logprobs: List[Dict[int, float]] = []
         self.cumulative_logprobs = 0.0
 
-    def add_block(self) -> None:
+    def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
             block_size=self.block_size,
         )
         self.logical_token_blocks.append(block)
 
-    def add(self, token_ids: List[int]) -> None:
+    def _append_tokens(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
-                self.add_block()
+                self._append_logical_block()
 
             last_block = self.logical_token_blocks[-1]
             if last_block.is_full():
-                self.add_block()
+                self._append_logical_block()
                 last_block = self.logical_token_blocks[-1]
 
             num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append(token_ids[:num_empty_slots])
+            last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]
 
-    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self.add([token_id])
+        self._append_tokens([token_id])
         self.output_logprobs.append(logprobs)
         self.cumulative_logprobs += logprobs[token_id]
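
A worked example of the renamed methods, using only behavior visible in this hunk (token values arbitrary): with block_size = 4, a 6-token prompt fills one block completely and two slots of a second, and append_token then adds one decode token to the last block:

    seq = Sequence(seq_id=0, prompt_token_ids=[11, 12, 13, 14, 15, 16],
                   block_size=4)
    assert seq.prompt_len == 6
    assert len(seq.logical_token_blocks) == 2
    assert seq.logical_token_blocks[0].is_full()        # tokens 11-14
    assert seq.logical_token_blocks[1].num_tokens == 2  # tokens 15, 16

    seq.append_token(17, {17: -0.25})
    assert seq.logical_token_blocks[1].num_tokens == 3
    assert seq.cumulative_logprobs == -0.25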
@@ -121,7 +122,7 @@ class SequenceGroup:
             f'num_seqs={len(self.seqs)})')
 
 
-class SequenceGroupInputs:
+class SequenceGroupMetadata:
 
     def __init__(
         self,

View File

@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Tuple, Optional
+from typing import List, Optional, Tuple, Union
 
 try:
     import ray
@ -6,7 +6,6 @@ except ImportError:
ray = None ray = None
from cacheflow.core.scheduler import Scheduler from cacheflow.core.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker from cacheflow.worker.worker import Worker
@@ -81,23 +80,12 @@ class Controller:
         self.next_node = next_node
         self.is_last_stage = isinstance(next_node, Scheduler)
 
-    def execute_stage(
-        self,
-        input_seq_groups: List[SequenceGroupInputs],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> None:
+    def execute_stage(self, *args, **kwargs) -> None:
         all_outputs = []
         for worker in self.workers:
             executor = (worker.execute_stage.remote
                         if self.use_ray else worker.execute_stage)
-            output = executor(
-                input_seq_groups,
-                blocks_to_swap_in,
-                blocks_to_swap_out,
-                blocks_to_copy,
-            )
+            output = executor(*args, **kwargs)
             all_outputs.append(output)
 
         if self.use_ray:
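
The *args/**kwargs signature turns the controller into a transparent proxy: it no longer restates Worker.execute_stage's parameter list, so the two signatures cannot drift apart (at the cost of type checking at this boundary). The scheduler's call site from the hunk above passes through unchanged:

    self.controllers[0].execute_stage(
        seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )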

View File

@@ -8,10 +8,11 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 
 
 class Worker:
 
     def __init__(
@@ -93,30 +94,29 @@ class Worker:
     def prepare_inputs(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         seq_logprobs: Dict[int, float] = {}
-        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
-        for input_seq_group in input_seq_groups:
-            if not input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if not seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             # Use any sequence in the group.
             seq_id = seq_ids[0]
-            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
@@ -126,7 +126,7 @@ class Worker:
             input_positions.extend(range(len(prompt_tokens)))
 
             # Compute the slot mapping.
-            block_table = input_seq_group.block_tables[seq_id]
+            block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
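
The slot mapping translates a logical token position into a physical cache slot. With block_size = 8 and a hypothetical block_table = [7, 3] (logical block 0 stored in physical block 7, logical block 1 in physical block 3), prompt position i = 11 resolves as:

    block_size = 8
    block_table = [7, 3]
    i = 11
    block_number = block_table[i // block_size]  # block_table[1] -> 3
    block_offset = i % block_size                # 3
    # The flat index appended to slot_mapping is presumably
    # block_number * block_size + block_offset -> 27; the line computing
    # it is cut off by this hunk.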
@@ -138,31 +138,31 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        for input_seq_group in input_seq_groups:
-            if input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             for seq_id in seq_ids:
-                assert len(input_seq_group.input_tokens[seq_id]) == 1
-                generation_token = input_seq_group.input_tokens[seq_id][0]
+                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
+                generation_token = seq_group_metadata.input_tokens[seq_id][0]
                 input_tokens.append(generation_token)
 
-                position = input_seq_group.context_len - 1
+                position = seq_group_metadata.context_len - 1
                 input_positions.append(position)
 
-                block_table = input_seq_group.block_tables[seq_id]
+                block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)
 
                 max_context_len = max(
-                    max_context_len, input_seq_group.context_len)
+                    max_context_len, seq_group_metadata.context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(input_seq_group.context_len)
+                context_lens.append(seq_group_metadata.context_len)
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
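
A decode step feeds exactly one token per sequence, at 0-indexed position context_len - 1; the same block arithmetic as in the prompt case then locates its slot. Reusing the assumed values from the previous sketch, with a hypothetical context_len = 12:

    context_len = 12
    position = context_len - 1                          # 11
    block_number = block_table[position // block_size]  # -> 3
    block_offset = position % block_size                # -> 3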
@@ -203,30 +203,30 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
-        command_issued = False
+        issued_cache_op = False
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            command_issued = True
+            issued_cache_op = True
 
-        if command_issued:
+        if issued_cache_op:
             cache_events = self.cache_events
         else:
             cache_events = None
 
         # If there is no input, we don't need to execute the model.
-        if not input_seq_groups:
+        if not seq_group_metadata_list:
             if cache_events is not None:
                 for event in cache_events:
                     event.wait()
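
The rename makes the flag's purpose explicit: self.cache_events may only be waited on when a swap or copy was actually issued this step, otherwise the worker would block on events that were never recorded. The flag computation is equivalent to this one-liner (sketch; the real code must still issue the operations in between):

    issued_cache_op = bool(blocks_to_swap_in or blocks_to_swap_out
                           or blocks_to_copy)
    cache_events = self.cache_events if issued_cache_op else None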
@@ -234,7 +234,7 @@ class Worker:
 
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            input_seq_groups)
+            seq_group_metadata_list)
 
         # Execute the model.
         output = self.model(