Support beam search & parallel generation (#7)

Woosuk Kwon 2023-03-10 09:58:21 -08:00 committed by GitHub
parent 04e5acc08e
commit 1a7eb7da61
16 changed files with 660 additions and 161 deletions

View File

@@ -35,6 +35,10 @@ class LogicalTokenBlock:
def get_token_ids(self) -> List[int]:
return self.token_ids[:self.num_tokens]
def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]
class PhysicalTokenBlock:

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple
from transformers import AutoTokenizer
@@ -25,12 +25,35 @@ class Frontend:
def query(
self,
prompt: str,
sampling_params: Optional[SamplingParams] = None,
n: int = 1,
temperature: float = 1.0,
top_p: float = 1.0,
use_beam_search: bool = False,
stop_token_ids: Set[int] = set(),
max_num_steps: int = 16, # From OpenAI API.
num_logprobs: int = 0,
context_window_size: Optional[int] = None,
) -> None:
if sampling_params is None:
sampling_params = SamplingParams()
token_ids: List[int] = self.tokenizer.encode(prompt)
# Stop when we see an EOS token.
stop_token_ids.add(self.tokenizer.eos_token_id)
sampling_params = SamplingParams(
n=n,
temperature=temperature,
top_p=top_p,
use_beam_search=use_beam_search,
stop_token_ids=stop_token_ids,
max_num_steps=max_num_steps,
num_logprobs=num_logprobs,
context_window_size=context_window_size,
)
token_ids = self.tokenizer.encode(prompt)
self._add_query(token_ids, sampling_params)
def _add_query(
self,
token_ids: List[int],
sampling_params: SamplingParams,
) -> None:
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
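
With this change, callers choose the decoding mode per request through keyword arguments, and the tokenizer's EOS id is always added to stop_token_ids. A minimal usage sketch follows; the Frontend constructor arguments are assumed for illustration (they are not part of this hunk), only the query() keywords come from the new signature above.

# Hypothetical driver: Frontend's constructor arguments are assumed here;
# only the query() keywords are taken from the diff.
from cacheflow.master.frontend import Frontend

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# Three independent samples per prompt via nucleus sampling.
frontend.query('UC Berkeley is', n=3, temperature=0.8, top_p=0.99)

# Beam search with beam width 4 (requires temperature=0.0 and top_p=1.0).
frontend.query('Ion Stoica is a', n=4, use_beam_search=True, temperature=0.0)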

View File

@@ -1,10 +1,12 @@
from typing import Dict, List, Tuple
from typing import Dict, List
from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus
_MAX_NUM_BATCHED_TOKENS = 2048
@@ -66,7 +68,7 @@ class Scheduler:
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
@@ -74,7 +76,10 @@
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
blocks_to_copy[src_block] = dst_block
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
def _swap_in(
self,
@@ -83,9 +88,8 @@
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.SWAPPED:
seq.status = SequenceStatus.RUNNING
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
def _swap_out(
@@ -96,16 +100,15 @@
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.RUNNING:
seq.status = SequenceStatus.SWAPPED
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)
def step(self) -> None:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
# 1. Reserve new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@
# All swapped sequences are swapped in.
self.swapped.clear()
# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
@@ -152,7 +159,6 @@
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
if not self.swapped:
# FIXME(woosuk): Acquire a lock to protect pending.
self._fetch_inputs()
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,39 +174,45 @@
else:
self.pending.clear()
# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out
# 4. Create input data structures.
prompt_tokens: Dict[int, List[int]] = {}
generation_tokens: Dict[int, int] = {}
context_lens: Dict[int, int] = {}
block_tables: Dict[int, List[int]] = {}
input_seq_groups: List[SequenceGroupInputs] = []
for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]
# NOTE(woosuk): We assume that the number of steps is 0
# for the prompt sequences.
is_prompt = num_steps == 0
for seq in seq_group.seqs:
if seq.status != SequenceStatus.RUNNING:
continue
input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
prompt_tokens[seq_id] = seq.get_token_ids()
input_tokens[seq_id] = seq.get_token_ids()
else:
generation_tokens[seq_id] = seq.get_token_ids()[-1]
context_lens[seq_id] = seq.get_len()
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
# NOTE(woosuk): Sequences in the same group have the same
# sequence length
seq_len = seq.get_len()
input_seq_group = SequenceGroupInputs(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)
# 5. Execute the first stage of the pipeline.
self.controllers[0].execute_stage(
prompt_tokens,
generation_tokens,
context_lens,
block_tables,
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
@@ -208,7 +220,7 @@ class Scheduler:
def post_step(
self,
next_tokens: Dict[int, Tuple[int, int]],
seq_outputs: Dict[int, SequenceOutputs],
) -> None:
# Update the running sequences and free blocks.
for seq_group in self.running:
@@ -216,25 +228,32 @@
self.num_steps[group_id] += 1
stop_token_ids = self.sampling_params[group_id].stop_token_ids
# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
parent_seq_id, next_token = next_tokens[seq.seq_id]
if seq.seq_id != parent_seq_id:
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(parent_seq_id)
seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Append a new token to the sequence.
seq.append([next_token])
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)
# Check if the sequence has generated a stop token.
if next_token in stop_token_ids:
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
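
Why blocks_to_copy became a dict of lists: with beam search, several child sequences may fork from the same parent in one step, so a single source block can need copies into multiple destination blocks. A small self-contained sketch of the append-or-create pattern used in _append above:

from typing import Dict, List

def record_copy(blocks_to_copy: Dict[int, List[int]],
                src_block: int, dst_block: int) -> None:
    # Same append-or-create pattern as in Scheduler._append.
    if src_block in blocks_to_copy:
        blocks_to_copy[src_block].append(dst_block)
    else:
        blocks_to_copy[src_block] = [dst_block]

blocks_to_copy: Dict[int, List[int]] = {}
record_copy(blocks_to_copy, src_block=7, dst_block=12)
record_copy(blocks_to_copy, src_block=7, dst_block=13)  # second fork of block 7
assert blocks_to_copy == {7: [12, 13]}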

View File

@@ -1,8 +1,10 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_model
from cacheflow.models.model_utils import set_seed
__all__ = [
'get_model',
'InputMetadata',
'get_model',
'set_seed'
]

View File

@@ -1,21 +1,24 @@
from typing import List
from typing import List, Dict, Tuple
import torch
from cacheflow.sampling_params import SamplingParams
class InputMetadata:
def __init__(
self,
seq_ids: List[int],
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
# FIXME: Rename
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
self.seq_ids = seq_ids
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
@@ -23,19 +26,20 @@ class InputMetadata:
self.block_tables = block_tables
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1]
else:
self.max_num_blocks_per_seq = 0
assert self.num_generation_tokens == block_tables.shape[0]
assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens
def __repr__(self) -> str:
return (f'InputMetadata('
f'seq_ids={self.seq_ids}, '
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '

View File

@@ -1,5 +1,7 @@
import random
from typing import Union
import numpy as np
import torch
import torch.nn as nn
@@ -30,3 +32,11 @@ def get_model(
model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
return model.eval()
raise ValueError(f'Invalid model name: {model_name}')
def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)

View File

@@ -9,6 +9,7 @@ from transformers import PreTrainedModel
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -261,7 +262,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, Tuple[int, int]]:
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(

View File

@@ -4,6 +4,8 @@ import torch
import torch.nn as nn
from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs
class Sampler(nn.Module):
@@ -16,27 +18,266 @@ class Sampler(nn.Module):
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, Tuple[int, int]]:
# Get the hidden states of the last tokens.
start_idx = 0
last_token_indicies: List[int] = []
for prompt_len in input_metadata.prompt_lens:
last_token_indicies.append(start_idx + prompt_len - 1)
start_idx += prompt_len
last_token_indicies.extend(
range(start_idx, start_idx + input_metadata.num_generation_tokens))
hidden_states = hidden_states[last_token_indicies]
) -> Dict[int, SequenceOutputs]:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
# Sample the next tokens.
# TODO(woosuk): Implement other sampling methods.
next_token_ids = torch.argmax(logits, dim=-1)
next_token_ids = next_token_ids.tolist()
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(
temperatures, dtype=logits.dtype, device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
# Return the next tokens.
next_tokens: Dict[int, Tuple[int, int]] = {}
for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids):
next_tokens[seq_id] = (seq_id, token_id)
return next_tokens
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p).
logprobs = torch.log(probs)
# Apply top-p truncation.
top_ps = _get_top_ps(input_metadata)
assert len(top_ps) == probs.shape[0]
if any(p < 1.0 for p in top_ps):
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
probs = _apply_top_p(probs, p)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
start_idx = 0
last_token_indicies: List[int] = []
for prompt_len in input_metadata.prompt_lens:
last_token_indicies.append(start_idx + prompt_len - 1)
start_idx += prompt_len
last_token_indicies.extend(
range(start_idx, start_idx + input_metadata.num_generation_tokens))
return hidden_states[last_token_indicies]
def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature == 0.0:
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if i < input_metadata.num_prompts:
# A prompt input.
temperatures.append(temperature)
else:
# A generation token.
temperatures += [temperature] * len(seq_ids)
return temperatures
def _get_top_ps(
input_metadata: InputMetadata,
) -> List[float]:
top_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
top_ps.append(sampling_params.top_p)
else:
# A generation token.
top_ps += [sampling_params.top_p] * len(seq_ids)
return top_ps
def _apply_top_p(
probs: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
# TODO(woosuk): Optimize.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
return probs
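
A self-contained numerical check of the truncation implemented by _apply_top_p; the probabilities and the p value are illustrative only.

import torch

probs = torch.tensor([[0.50, 0.25, 0.15, 0.10]])
p = torch.tensor([0.7])

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
# Drop every token whose preceding cumulative mass already exceeds p.
mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
probs_sort[mask] = 0.0
# Renormalize the survivors and undo the sort.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
out = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
print(out)  # tensor([[0.6667, 0.3333, 0.0000, 0.0000]])
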
def _get_topk_logprobs(
logprobs: torch.Tensor,
num_logprobs: int,
) -> Dict[int, float]:
if num_logprobs == 0:
return {}
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
if num_logprobs == 1:
topk_logprobs = [topk_logprobs.item()]
topk_ids = [topk_ids.item()]
else:
topk_logprobs = topk_logprobs.tolist()
topk_ids = topk_ids.tolist()
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id] = logprob
return token_to_logprob
def _sample_from_prompt(
prob: torch.Tensor,
sampling_params: SamplingParams,
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
# Nucleus sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
def _sample_from_generation_tokens(
seq_ids: List[int],
probs: torch.Tensor,
logprobs: torch.Tensor,
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.n can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(
seq_logprobs, dtype=torch.float, device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
beam_seq_ids = [seq_ids[i] for i in seq_idx]
token_ids = (topk_ids % vocab_size).tolist()
beam_outputs: Dict[int, Tuple[int, int]] = {}
outstanding_beams: List[Tuple[int, int]] = []
# If a beam survives, continue with it.
for seq_id, token_id in zip(beam_seq_ids, token_ids):
if seq_id not in beam_outputs:
beam_outputs[seq_id] = (seq_id, token_id)
else:
outstanding_beams.append((seq_id, token_id))
# If a beam is discarded, fork another beam.
for seq_id in seq_ids:
if seq_id not in beam_outputs:
beam_outputs[seq_id] = outstanding_beams.pop()
assert not outstanding_beams
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert len(seq_ids) == 1
next_token_id = torch.argmax(probs, dim=-1)
next_token_ids = [next_token_id.item()]
parent_seq_ids = seq_ids
else:
# Nucleus sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(
probs, num_samples=1, replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
seq_outputs: Dict[int, SequenceOutputs] = {}
# TODO(woosuk): Optimize.
idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.num_logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id, next_token_id, output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
logprob = logprobs[idx:idx + len(seq_ids)]
idx += len(seq_ids)
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.num_logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
next_token_id,
output_logprobs,
)
return seq_outputs
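
To see how the beam bookkeeping above behaves, consider two active beams over a toy vocabulary of three tokens (numbers are illustrative). Both top candidates come from the stronger beam, so the weaker beam is recycled as a fork of it, which is what a SequenceOutputs with parent_seq_id != seq_id tells the scheduler's post_step to do.

import torch

seq_ids = [10, 11]
seq_logprobs = [-1.0, -2.5]  # cumulative logprobs of the two beams so far
logprobs = torch.log(torch.tensor([[0.7, 0.2, 0.1],
                                   [0.6, 0.3, 0.1]]))

total = logprobs + torch.tensor(seq_logprobs).unsqueeze(dim=1)
vocab_size = total.size(-1)
_, topk_ids = torch.topk(total.flatten(), len(seq_ids))
seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
beam_seq_ids = [seq_ids[i] for i in seq_idx]  # [10, 10]: both winners come from beam 10
token_ids = (topk_ids % vocab_size).tolist()  # [0, 1]
# Beam 10 survives with token 0; the second candidate has no free slot, so
# beam 11 is reassigned as a fork of beam 10 carrying token 1
# (parent_seq_ids == [10, 10], next_token_ids == [0, 1]).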

View File

@@ -5,27 +5,51 @@ class SamplingParams:
def __init__(
self,
n: int = 1,
temperature: float = 1.0,
top_p: float = 1.0,
use_beam_search: bool = False,
stop_token_ids: Set[int] = [],
max_num_steps: int = 16, # From OpenAI API.
max_context_len: Optional[int] = None,
n: int,
temperature: float,
top_p: float,
use_beam_search: bool,
stop_token_ids: Set[int],
max_num_steps: int,
num_logprobs: int,
context_window_size: Optional[int],
) -> None:
assert n >= 1
assert temperature >= 0.0
assert 0.0 < top_p <= 1.0
if n < 1:
raise ValueError(f'n must be at least 1, got {n}.')
if temperature < 0.0:
raise ValueError(
f'temperature must be non-negative, got {temperature}.')
if not 0.0 < top_p <= 1.0:
raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
if max_num_steps < 1:
raise ValueError(
f'max_num_steps must be at least 1, got {max_num_steps}.')
if num_logprobs < 0:
raise ValueError(
f'num_logprobs must be non-negative, got {num_logprobs}.')
if context_window_size is not None and context_window_size < 0:
raise ValueError(
'context_window_size must be non-negative, '
f'got {context_window_size}.')
if use_beam_search:
assert n > 1
assert temperature > 0.0
assert top_p == 1.0
if n == 1:
raise ValueError(
'n must be greater than 1 when using beam search.')
if temperature > 0.0:
raise ValueError(
'temperature must be 0 when using beam search.')
if top_p < 1.0:
raise ValueError(
'top_p must be 1 when using beam search.')
elif temperature == 0.0:
# Zero temperature means greedy decoding.
assert n == 1
assert top_p == 1.0
assert max_num_steps >= 1
assert max_context_len is None or max_context_len >= 0
# Zero temperature means greedy sampling.
if n > 1:
raise ValueError(
'n must be 1 when using greedy sampling.')
if top_p < 1.0:
raise ValueError(
'top_p must be 1 when using greedy sampling.')
self.n = n
self.temperature = temperature
@@ -33,4 +57,15 @@
self.use_beam_search = use_beam_search
self.stop_token_ids = stop_token_ids
self.max_num_steps = max_num_steps
self.max_context_len = max_context_len
self.num_logprobs = num_logprobs
self.context_window_size = context_window_size
def __repr__(self) -> str:
return (f'SamplingParams(n={self.n}, '
f'temperature={self.temperature}, '
f'top_p={self.top_p}, '
f'use_beam_search={self.use_beam_search}, '
f'stop_token_ids={self.stop_token_ids}, '
f'max_num_steps={self.max_num_steps}, '
f'num_logprobs={self.num_logprobs}, '
f'context_window_size={self.context_window_size})')
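
For reference, the constructor now enforces the mode-specific constraints with explicit ValueErrors; a quick sketch with illustrative values (stop token id 2 stands in for an EOS id):

from cacheflow.sampling_params import SamplingParams

# Beam search: n > 1, temperature must be 0.0 and top_p must be 1.0.
beam = SamplingParams(n=4, temperature=0.0, top_p=1.0, use_beam_search=True,
                      stop_token_ids={2}, max_num_steps=16, num_logprobs=0,
                      context_window_size=None)

# Parallel generation: n independent samples with nucleus sampling.
parallel = SamplingParams(n=3, temperature=0.8, top_p=0.99,
                          use_beam_search=False, stop_token_ids={2},
                          max_num_steps=16, num_logprobs=0,
                          context_window_size=None)

# Greedy sampling (temperature 0.0 without beam search) rejects n > 1.
try:
    SamplingParams(n=2, temperature=0.0, top_p=1.0, use_beam_search=False,
                   stop_token_ids={2}, max_num_steps=16, num_logprobs=0,
                   context_window_size=None)
except ValueError as err:
    print(err)  # n must be 1 when using greedy sampling.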

View File

@@ -1,7 +1,9 @@
import copy
import enum
from typing import List, Optional
from typing import Dict, List, Optional
from cacheflow.block import LogicalTokenBlock
from cacheflow.sampling_params import SamplingParams
class SequenceStatus(enum.Enum):
@@ -24,9 +26,11 @@ class Sequence:
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the given token ids.
self.append(token_ids)
self.add(token_ids)
self.status = SequenceStatus.PENDING
self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0
def add_block(self) -> None:
block = LogicalTokenBlock(
@@ -35,7 +39,7 @@
)
self.logical_token_blocks.append(block)
def append(self, token_ids: List[int]) -> None:
def add(self, token_ids: List[int]) -> None:
while token_ids:
if not self.logical_token_blocks:
self.add_block()
@@ -49,6 +53,12 @@
last_block.append(token_ids[:num_empty_slots])
token_ids = token_ids[num_empty_slots:]
def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
assert token_id in logprobs
self.add([token_id])
self.output_logprobs.append(logprobs)
self.cumulative_logprobs += logprobs[token_id]
def get_len(self) -> int:
return sum(block.num_tokens for block in self.logical_token_blocks)
@@ -58,6 +68,14 @@
token_ids.extend(block.get_token_ids())
return token_ids
def get_last_token_id(self) -> int:
return self.logical_token_blocks[-1].get_last_token_id()
def fork(self, child_seq: 'Sequence') -> 'Sequence':
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.cumulative_logprobs = self.cumulative_logprobs
def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, '
f'status={self.status.name}, '
@@ -74,11 +92,17 @@ class SequenceGroup:
self.group_id = group_id
self.seqs = seqs
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
def get_seqs(
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
if status is None:
return len(self.seqs)
return self.seqs
else:
return len([seq for seq in self.seqs if seq.status == status])
return [seq for seq in self.seqs if seq.status == status]
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))
def find(self, seq_id: int) -> Sequence:
for seq in self.seqs:
@@ -92,3 +116,45 @@
def __repr__(self) -> str:
return (f'SequenceGroup(group_id={self.group_id}, '
f'num_seqs={len(self.seqs)})')
class SequenceGroupInputs:
def __init__(
self,
group_id: int,
is_prompt: bool,
input_tokens: Dict[int, List[int]], # Seq id -> token ids.
context_len: int,
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
) -> None:
self.group_id = group_id
self.is_prompt = is_prompt
self.input_tokens = input_tokens
self.context_len = context_len
self.seq_logprobs = seq_logprobs
self.sampling_params = sampling_params
self.block_tables = block_tables
class SequenceOutputs:
def __init__(
self,
seq_id: int,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
) -> None:
self.seq_id = seq_id
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
def __repr__(self) -> str:
return (f'SequenceOutputs(seq_id={self.seq_id}, '
f'parent_seq_id={self.parent_seq_id}, '
f'output_token={self.output_token}, '
f'logprobs={self.logprobs})')
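
A sketch of the new per-token bookkeeping and fork semantics; the Sequence constructor arguments (seq_id, token_ids, block_size) are assumed for illustration, since the constructor itself is outside this hunk.

# Assumed constructor: Sequence(seq_id, token_ids, block_size).
parent = Sequence(seq_id=0, token_ids=[5, 17, 42], block_size=8)
child = Sequence(seq_id=1, token_ids=[5, 17, 42], block_size=8)

# append() now records the sampled token together with its logprob and
# folds it into cumulative_logprobs.
parent.append(99, logprobs={99: -0.25, 7: -1.8})

# fork() deep-copies the parent's logical blocks and logprob history into
# the child, so a recycled beam continues from the parent's state.
parent.fork(child)
assert child.get_last_token_id() == 99
assert child.cumulative_logprobs == parent.cumulative_logprobs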

View File

@@ -97,7 +97,7 @@ class CacheEngine:
cpu_cache.append((key_blocks, value_blocks))
return cpu_cache
def _copy_blocks(
def _swap(
self,
src: List[KVCache],
dst: List[KVCache],
@@ -108,19 +108,38 @@
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
cache_ops.copy_cache_blocks(
cache_ops.swap_blocks(
src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
cache_ops.copy_cache_blocks(
cache_ops.swap_blocks(
src_value_cache, dst_value_cache, src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)
def copy(self, src_to_dst: Dict[int, int]) -> None:
self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
def _copy(
self,
src: List[KVCache],
dst: List[KVCache],
src_to_dsts: Dict[int, List[int]],
) -> None:
with torch.cuda.stream(self.cache_stream):
for i in range(self.num_layers):
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
cache_ops.copy_blocks(
src_key_cache, dst_key_cache, src_to_dsts)
# Copy the value blocks.
cache_ops.copy_blocks(
src_value_cache, dst_value_cache, src_to_dsts)
event = self.events[i]
event.record(stream=self.cache_stream)
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
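
The swap mappings stay one-to-one while copies now fan out; a hypothetical call pattern (constructing the CacheEngine is omitted) just to show the mapping shapes:

# cache_engine is assumed to be an already-constructed CacheEngine.
cache_engine.swap_in({3: 0, 4: 1})         # CPU block -> GPU block, one-to-one
cache_engine.swap_out({0: 3, 1: 4})        # GPU block -> CPU block, one-to-one
cache_engine.copy({7: [12, 13], 9: [14]})  # one GPU block -> several GPU copies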

View File

@@ -1,6 +1,7 @@
from typing import Dict, List, Union
from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker
@@ -14,7 +15,8 @@ class Controller:
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
dtype: str = 'half',
dtype: str,
seed: int,
) -> None:
self.node_id = node_id
self.num_workers = num_workers
@@ -37,6 +39,7 @@
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=dtype,
seed=seed,
)
self.workers.append(worker)
@@ -49,22 +52,16 @@
def execute_stage(
self,
prompt_tokens: Dict[int, List[int]],
generation_tokens: Dict[int, int],
context_lens: Dict[int, int],
block_tables: Dict[int, List[int]],
input_seq_groups: List[SequenceGroupInputs],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
# FIXME: Support tensor parallelism.
assert len(self.workers) == 1
worker = self.workers[0]
output = worker.execute_stage(
prompt_tokens,
generation_tokens,
context_lens,
block_tables,
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,

View File

@@ -1,9 +1,13 @@
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple
import torch
from cacheflow.models import get_model
from cacheflow.models import set_seed
from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine
@@ -18,6 +22,7 @@ class Worker:
num_gpu_blocks: int,
num_cpu_blocks: int,
dtype: str,
seed: int,
) -> None:
self.worker_id = worker_id
self.gpu_id = gpu_id
@@ -33,6 +38,11 @@
self.head_size = self.model.config.hidden_size // self.num_heads
self.dtype = self.model.dtype
# Set the seed.
# We set the seed after initializing the model to ensure that
# the random state is not affected by the model initialization.
set_seed(seed)
self.cache_engine = CacheEngine(
worker_id=worker_id,
gpu_id=gpu_id,
@@ -49,55 +59,81 @@
def prepare_inputs(
self,
prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids.
generation_tokens: Dict[int, int], # Seq id -> Input token id.
context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention.
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
input_seq_groups: List[SequenceGroupInputs],
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
# TODO(woosuk): Support interactive generation.
# Add the prompt tokens.
prompt_lens: List[int] = []
seq_groups: List[Tuple[List[int], SamplingParams]] = []
seq_logprobs: Dict[int, float] = {}
sampling_params: Dict[int, SamplingParams] = {}
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_seq_ids = sorted(prompt_tokens.keys())
for seq_id in prompt_seq_ids:
prompt_len = len(prompt_tokens[seq_id])
# Add prompt tokens.
prompt_lens: List[int] = []
for input_seq_group in input_seq_groups:
if not input_seq_group.is_prompt:
continue
seq_ids = list(input_seq_group.input_tokens.keys())
sampling_params = input_seq_group.sampling_params
seq_groups.append((seq_ids, sampling_params))
seq_logprobs.update(input_seq_group.seq_logprobs)
# Use any sequence in the group.
seq_id = seq_ids[0]
prompt_tokens = input_seq_group.input_tokens[seq_id]
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens[seq_id])
input_positions.extend(range(len(prompt_tokens[seq_id])))
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(range(len(prompt_tokens)))
block_table = block_tables[seq_id]
# Compute the slot mapping.
block_table = input_seq_group.block_tables[seq_id]
for i in range(prompt_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# Add the generation tokens.
# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
context_lens: List[int] = []
generation_block_tables: List[List[int]] = []
for input_seq_group in input_seq_groups:
if input_seq_group.is_prompt:
continue
generation_seq_ids = sorted(generation_tokens.keys())
for seq_id in generation_seq_ids:
input_tokens.append(generation_tokens[seq_id])
position_id = context_lens[seq_id] - 1
input_positions.append(position_id)
seq_ids = list(input_seq_group.input_tokens.keys())
sampling_params = input_seq_group.sampling_params
seq_groups.append((seq_ids, sampling_params))
seq_logprobs.update(input_seq_group.seq_logprobs)
block_table = block_tables[seq_id]
generation_block_tables.append(block_table)
for seq_id in seq_ids:
assert len(input_seq_group.input_tokens[seq_id]) == 1
generation_token = input_seq_group.input_tokens[seq_id][0]
input_tokens.append(generation_token)
max_context_len = max(max_context_len, context_lens[seq_id])
max_num_blocks_per_seq = max(
max_num_blocks_per_seq, len(block_table))
position = input_seq_group.context_len - 1
input_positions.append(position)
block_number = block_table[position_id // self.block_size]
block_offset = position_id % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
block_table = input_seq_group.block_tables[seq_id]
generation_block_tables.append(block_table)
max_context_len = max(
max_context_len, input_seq_group.context_len)
max_num_blocks_per_seq = max(
max_num_blocks_per_seq, len(block_table))
context_lens.append(input_seq_group.context_len)
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# Optimization: Pad the input length to be a multiple of 8.
# This is required for utilizing the Tensor Cores in NVIDIA GPUs.
@@ -112,8 +148,7 @@ class Worker:
slot_mapping_tensor = torch.tensor(
slot_mapping, dtype=torch.int, device=self.device)
context_lens_tensor = torch.tensor(
[context_lens[seq_id] for seq_id in generation_seq_ids],
dtype=torch.int, device=self.device)
context_lens, dtype=torch.int, device=self.device)
padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
@@ -121,7 +156,8 @@
padded_block_tables, dtype=torch.int, device=self.device)
input_metadata = InputMetadata(
seq_ids=prompt_seq_ids + generation_seq_ids,
seq_groups=seq_groups,
seq_logprobs=seq_logprobs,
prompt_lens=prompt_lens,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
@@ -133,14 +169,11 @@
@torch.inference_mode()
def execute_stage(
self,
prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids.
generation_tokens: Dict[int, int], # Seq id -> Input token id.
context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention.
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
input_seq_groups: List[SequenceGroupInputs],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, int],
) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
blocks_to_copy: Dict[int, List[int]],
) -> Dict[int, SequenceOutputs]:
# Issue cache operations.
command_issued = False
if blocks_to_swap_in:
@@ -160,7 +193,7 @@
# Prepare input tensors.
input_tokens, input_positions, input_metadata = self.prepare_inputs(
prompt_tokens, generation_tokens, context_lens, block_tables)
input_seq_groups)
# Execute the model.
output = self.model(
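
The slot mapping computed in prepare_inputs turns a token's logical position into a physical cache slot; a standalone arithmetic example with an assumed block_size of 8 and a made-up block table:

block_size = 8
block_table = [2, 5]  # logical block 0 -> physical block 2, block 1 -> physical block 5

for position in (0, 7, 8, 10):  # token positions within one sequence
    block_number = block_table[position // block_size]
    block_offset = position % block_size
    slot = block_number * block_size + block_offset
    print(position, slot)  # 0 -> 16, 7 -> 23, 8 -> 40, 10 -> 42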

View File

@@ -1,9 +1,17 @@
#include <torch/extension.h>
#include <map>
#include <vector>
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void copy_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void reshape_and_cache(
torch::Tensor& key,
@@ -14,7 +22,11 @@ void reshape_and_cache(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"copy_cache_blocks",
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
m.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
m.def(

View File

@@ -5,8 +5,9 @@
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
void copy_blocks(
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
@@ -43,6 +44,35 @@
}
}
void copy_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
assert(src_device.is_cuda() && dst_device.is_cuda());
cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
void *src_ptr = src.data_ptr();
void *dst_ptr = dst.data_ptr();
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(
dst_ptr + dst_offset,
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
}
}
}
template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]

View File

@@ -15,6 +15,8 @@ parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of
parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()
@@ -30,6 +32,7 @@ def main():
num_gpu_blocks=args.num_gpu_blocks,
num_cpu_blocks=args.num_cpu_blocks,
dtype=args.dtype,
seed=args.seed,
)
controllers.append(controller)
@@ -52,18 +55,18 @@ def main():
controllers[i].set_next(controllers[i + 1])
controllers[-1].set_next(scheduler)
# Test the following inputs.
test_inputs = [
'Ion Stoica is a',
'UC Berkeley is',
'The future of cloud computing is',
('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
('The future of cloud computing is', {}), # Use default parameters.
]
# FIXME
while True:
if test_inputs:
frontend.query(test_inputs.pop())
text, sampling_params = test_inputs.pop(0)
frontend.query(text, **sampling_params)
scheduler.step()
if not scheduler.pending and not scheduler.running:
if not (scheduler.pending or scheduler.running or test_inputs):
break