Support beam search & parallel generation (#7)

Woosuk Kwon 2023-03-10 09:58:21 -08:00 committed by GitHub
parent 04e5acc08e
commit 1a7eb7da61
16 changed files with 660 additions and 161 deletions

View File

@@ -35,6 +35,10 @@ class LogicalTokenBlock:
    def get_token_ids(self) -> List[int]:
        return self.token_ids[:self.num_tokens]

+    def get_last_token_id(self) -> int:
+        assert self.num_tokens > 0
+        return self.token_ids[self.num_tokens - 1]
+
class PhysicalTokenBlock:

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

from transformers import AutoTokenizer

@@ -25,12 +25,35 @@ class Frontend:
    def query(
        self,
        prompt: str,
-        sampling_params: Optional[SamplingParams] = None,
+        n: int = 1,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        use_beam_search: bool = False,
+        stop_token_ids: Set[int] = set(),
+        max_num_steps: int = 16,  # From OpenAI API.
+        num_logprobs: int = 0,
+        context_window_size: Optional[int] = None,
    ) -> None:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        token_ids: List[int] = self.tokenizer.encode(prompt)
+        # Stop when we see an EOS token.
+        stop_token_ids.add(self.tokenizer.eos_token_id)
+        sampling_params = SamplingParams(
+            n=n,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=use_beam_search,
+            stop_token_ids=stop_token_ids,
+            max_num_steps=max_num_steps,
+            num_logprobs=num_logprobs,
+            context_window_size=context_window_size,
+        )
+        token_ids = self.tokenizer.encode(prompt)
+        self._add_query(token_ids, sampling_params)
+
+    def _add_query(
+        self,
+        token_ids: List[int],
+        sampling_params: SamplingParams,
+    ) -> None:
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
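
A minimal usage sketch of the widened query() signature, assuming a Frontend instance like the one the test driver at the bottom of this commit constructs; the prompt and parameter values are purely illustrative:

    # Request four beam-search completions for one prompt.
    # Beam search requires temperature=0.0 and top_p=1.0 (see SamplingParams below).
    frontend.query(
        'The future of cloud computing is',
        n=4,
        use_beam_search=True,
        temperature=0.0,
        max_num_steps=32,
        num_logprobs=2,   # also report the top-2 logprobs for each generated token
    )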

View File

@@ -1,10 +1,12 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus

_MAX_NUM_BATCHED_TOKENS = 2048
@@ -66,7 +68,7 @@ class Scheduler:
    def _append(
        self,
        seq_group: SequenceGroup,
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
@@ -74,7 +76,10 @@ class Scheduler:
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
-                blocks_to_copy[src_block] = dst_block
+                if src_block in blocks_to_copy:
+                    blocks_to_copy[src_block].append(dst_block)
+                else:
+                    blocks_to_copy[src_block] = [dst_block]

    def _swap_in(
        self,
@@ -83,9 +88,8 @@
    ) -> None:
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.SWAPPED:
-                seq.status = SequenceStatus.RUNNING
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
+            seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)

    def _swap_out(
@@ -96,16 +100,15 @@
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
-        for seq in seq_group.seqs:
-            if seq.status == SequenceStatus.RUNNING:
-                seq.status = SequenceStatus.SWAPPED
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            seq.status = SequenceStatus.SWAPPED
        self.swapped.append(seq_group)

    def step(self) -> None:
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, int] = {}
+        blocks_to_copy: Dict[int, List[int]] = {}

        # 1. Reserve new slots for the running sequences.
        # NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@ class Scheduler:
            # All swapped sequences are swapped in.
            self.swapped.clear()

+        # Ensure that swap-in and swap-out never happen at the same timestep.
+        if blocks_to_swap_in:
+            assert not blocks_to_swap_out
+
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
@@ -152,7 +159,6 @@
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        if not self.swapped:
-            # FIXME(woosuk): Acquire a lock to protect pending.
            self._fetch_inputs()
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,39 +174,45 @@ class Scheduler:
            else:
                self.pending.clear()

-        # Ensure that swap-in and swap-out never happen at the same timestep.
-        if blocks_to_swap_in:
-            assert not blocks_to_swap_out
-
        # 4. Create input data structures.
-        prompt_tokens: Dict[int, List[int]] = {}
-        generation_tokens: Dict[int, int] = {}
-        context_lens: Dict[int, int] = {}
-        block_tables: Dict[int, List[int]] = {}
+        input_seq_groups: List[SequenceGroupInputs] = []
        for seq_group in self.running:
            group_id = seq_group.group_id
            num_steps = self.num_steps[group_id]
            # NOTE(woosuk): We assume that the number of steps is 0
            # for the prompt sequences.
            is_prompt = num_steps == 0
-            for seq in seq_group.seqs:
-                if seq.status != SequenceStatus.RUNNING:
-                    continue
+
+            input_tokens: Dict[int, List[int]] = {}
+            seq_logprobs: Dict[int, float] = {}
+            block_tables: Dict[int, List[int]] = {}
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
-                    prompt_tokens[seq_id] = seq.get_token_ids()
+                    input_tokens[seq_id] = seq.get_token_ids()
                else:
-                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
-                    context_lens[seq_id] = seq.get_len()
+                    input_tokens[seq_id] = [seq.get_last_token_id()]
+                seq_logprobs[seq_id] = seq.cumulative_logprobs
+                # NOTE(woosuk): Sequences in the same group have the same
+                # sequence length
+                seq_len = seq.get_len()
+
+            input_seq_group = SequenceGroupInputs(
+                group_id=group_id,
+                is_prompt=is_prompt,
+                input_tokens=input_tokens,
+                context_len=seq_len,
+                seq_logprobs=seq_logprobs,
+                sampling_params=self.sampling_params[group_id],
+                block_tables=block_tables,
+            )
+            input_seq_groups.append(input_seq_group)

        # 5. Execute the first stage of the pipeline.
        self.controllers[0].execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
@@ -208,7 +220,7 @@ class Scheduler:
    def post_step(
        self,
-        next_tokens: Dict[int, Tuple[int, int]],
+        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
        # Update the running sequences and free blocks.
        for seq_group in self.running:
@@ -216,25 +228,32 @@ class Scheduler:
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

+            # Process beam search results before processing the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue
-                parent_seq_id, next_token = next_tokens[seq.seq_id]
-                if seq.seq_id != parent_seq_id:
+                output = seq_outputs[seq.seq_id]
+                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
-                    parent_seq = seq_group.find(parent_seq_id)
-                    seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
+                    parent_seq = seq_group.find(output.parent_seq_id)
+                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

+            # Process the next tokens.
+            for seq in seq_group.seqs:
+                if seq.status == SequenceStatus.FINISHED:
+                    continue
+
                # Append a new token to the sequence.
-                seq.append([next_token])
+                output = seq_outputs[seq.seq_id]
+                seq.append(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
-                if next_token in stop_token_ids:
+                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue
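
Switching blocks_to_copy from Dict[int, int] to Dict[int, List[int]] matters for parallel generation: in one step, several child sequences can fork from the same parent, so a single source block may need to be copied into multiple destination blocks. A small self-contained sketch of the accumulation pattern used in _append above, with invented block numbers:

    from typing import Dict, List

    blocks_to_copy: Dict[int, List[int]] = {}

    def record_copy(src_block: int, dst_block: int) -> None:
        # Same pattern as Scheduler._append: append to the list if the source
        # block already has a pending copy, otherwise start a new list.
        if src_block in blocks_to_copy:
            blocks_to_copy[src_block].append(dst_block)
        else:
            blocks_to_copy[src_block] = [dst_block]

    record_copy(7, 12)   # first fork copies block 7 into block 12
    record_copy(7, 13)   # second fork copies block 7 into block 13
    assert blocks_to_copy == {7: [12, 13]}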

View File

@@ -1,8 +1,10 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_model
+from cacheflow.models.model_utils import set_seed

__all__ = [
-    'get_model',
    'InputMetadata',
+    'get_model',
+    'set_seed'
]

View File

@@ -1,21 +1,24 @@
-from typing import List
+from typing import List, Dict, Tuple

import torch

+from cacheflow.sampling_params import SamplingParams
+

class InputMetadata:

    def __init__(
        self,
-        seq_ids: List[int],
+        seq_groups: List[Tuple[List[int], SamplingParams]],
+        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
-        # FIXME: Rename
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
-        self.seq_ids = seq_ids
+        self.seq_groups = seq_groups
+        self.seq_logprobs = seq_logprobs
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
@@ -23,19 +26,20 @@ class InputMetadata:
        self.block_tables = block_tables

        self.num_prompts = len(prompt_lens)
+        self.num_prompt_tokens = sum(prompt_lens)
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = slot_mapping.shape[0]
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
-        assert self.num_generation_tokens == block_tables.shape[0]
-        assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
+        assert block_tables.shape[0] == self.num_generation_tokens
+        assert context_lens.shape[0] == self.num_generation_tokens

    def __repr__(self) -> str:
        return (f'InputMetadata('
-                f'seq_ids={self.seq_ids}, '
                f'num_prompts={self.num_prompts}, '
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
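
A hedged construction sketch showing how the new fields fit together, assuming the cacheflow package from this commit is importable; all ids, lengths, and block numbers are invented:

    import torch
    from cacheflow.models.input_metadata import InputMetadata
    from cacheflow.sampling_params import SamplingParams

    greedy = SamplingParams(
        n=1, temperature=0.0, top_p=1.0, use_beam_search=False,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None)
    sample = SamplingParams(
        n=2, temperature=0.8, top_p=1.0, use_beam_search=False,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None)

    # One prompt group (seq 0, four prompt tokens) followed by one generation
    # group (seqs 1 and 2, one new token each).
    metadata = InputMetadata(
        seq_groups=[([0], greedy), ([1, 2], sample)],
        seq_logprobs={0: 0.0, 1: -1.2, 2: -0.9},
        prompt_lens=[4],
        slot_mapping=torch.tensor([0, 1, 2, 3, 35, 43], dtype=torch.int),
        context_lens=torch.tensor([5, 5], dtype=torch.int),
        max_context_len=5,
        block_tables=torch.tensor([[4], [5]], dtype=torch.int),
    )
    assert metadata.num_prompts == 1 and metadata.num_prompt_tokens == 4
    assert metadata.num_generation_tokens == 2 and metadata.num_valid_tokens == 6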

View File

@@ -1,5 +1,7 @@
+import random
from typing import Union

+import numpy as np
import torch
import torch.nn as nn

@@ -30,3 +32,11 @@ def get_model(
        model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
        return model.eval()
    raise ValueError(f'Invalid model name: {model_name}')
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)

View File

@@ -9,6 +9,7 @@ from transformers import PreTrainedModel
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
+from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -261,7 +262,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, Tuple[int, int]]:
+    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(

View File

@@ -4,6 +4,8 @@ import torch
import torch.nn as nn

from cacheflow.models import InputMetadata
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceOutputs


class Sampler(nn.Module):
@@ -16,27 +18,266 @@ class Sampler(nn.Module):
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
-    ) -> Dict[int, Tuple[int, int]]:
-        # Get the hidden states of the last tokens.
-        start_idx = 0
-        last_token_indicies: List[int] = []
-        for prompt_len in input_metadata.prompt_lens:
-            last_token_indicies.append(start_idx + prompt_len - 1)
-            start_idx += prompt_len
-        last_token_indicies.extend(
-            range(start_idx, start_idx + input_metadata.num_generation_tokens))
-        hidden_states = hidden_states[last_token_indicies]
+    ) -> Dict[int, SequenceOutputs]:
+        # Get the hidden states that we use for sampling.
+        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())

-        # Sample the next tokens.
-        # TODO(woosuk): Implement other sampling methods.
-        next_token_ids = torch.argmax(logits, dim=-1)
-        next_token_ids = next_token_ids.tolist()
+        # Apply temperature scaling.
+        temperatures = _get_temperatures(input_metadata)
+        assert len(temperatures) == logits.shape[0]
+        if any(t != 1.0 for t in temperatures):
+            t = torch.tensor(
+                temperatures, dtype=logits.dtype, device=logits.device)
+            # Use in-place division to avoid creating a new tensor.
+            logits.div_(t.unsqueeze(dim=1))

-        # Return the next tokens.
-        next_tokens: Dict[int, Tuple[int, int]] = {}
-        for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids):
-            next_tokens[seq_id] = (seq_id, token_id)
-        return next_tokens
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities (before applying top-p).
+        logprobs = torch.log(probs)
+
+        # Apply top-p truncation.
+        top_ps = _get_top_ps(input_metadata)
+        assert len(top_ps) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps):
+            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+            probs = _apply_top_p(probs, p)
+
+        # Sample the next tokens.
+        return _sample(probs, logprobs, input_metadata)
def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
start_idx = 0
last_token_indicies: List[int] = []
for prompt_len in input_metadata.prompt_lens:
last_token_indicies.append(start_idx + prompt_len - 1)
start_idx += prompt_len
last_token_indicies.extend(
range(start_idx, start_idx + input_metadata.num_generation_tokens))
return hidden_states[last_token_indicies]
def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature == 0.0:
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if i < input_metadata.num_prompts:
# A prompt input.
temperatures.append(temperature)
else:
# A generation token.
temperatures += [temperature] * len(seq_ids)
return temperatures
def _get_top_ps(
input_metadata: InputMetadata,
) -> List[float]:
top_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
top_ps.append(sampling_params.top_p)
else:
# A generation token.
top_ps += [sampling_params.top_p] * len(seq_ids)
return top_ps
def _apply_top_p(
probs: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
# TODO(woosuk): Optimize.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
return probs
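
A small self-contained check of the truncation above, using invented probabilities and p = 0.9: the cumulative mass that comes before each sorted probability decides whether it is kept, so only the 0.05 tail is zeroed and the rest is renormalized.

    import torch

    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    p = torch.tensor([0.9])

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)          # [0.50, 0.80, 0.95, 1.00]
    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)  # only the last entry exceeds 0.9
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # probs_sort is now roughly [0.526, 0.316, 0.158, 0.0].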
def _get_topk_logprobs(
logprobs: torch.Tensor,
num_logprobs: int,
) -> Dict[int, float]:
if num_logprobs == 0:
return {}
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
if num_logprobs == 1:
topk_logprobs = [topk_logprobs.item()]
topk_ids = [topk_ids.item()]
else:
topk_logprobs = topk_logprobs.tolist()
topk_ids = topk_ids.tolist()
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id] = logprob
return token_to_logprob
def _sample_from_prompt(
prob: torch.Tensor,
sampling_params: SamplingParams,
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
# Nucleus sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
def _sample_from_generation_tokens(
seq_ids: List[int],
probs: torch.Tensor,
logprobs: torch.Tensor,
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.n can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(
seq_logprobs, dtype=torch.float, device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
beam_seq_ids = [seq_ids[i] for i in seq_idx]
token_ids = (topk_ids % vocab_size).tolist()
beam_outputs: Dict[int, Tuple[int, int]] = {}
outstanding_beams: List[Tuple[int, int]] = []
# If a beam survives, continue with it.
for seq_id, token_id in zip(beam_seq_ids, token_ids):
if seq_id not in beam_outputs:
beam_outputs[seq_id] = (seq_id, token_id)
else:
outstanding_beams.append((seq_id, token_id))
# If a beam is discarded, fork another beam.
for seq_id in seq_ids:
if seq_id not in beam_outputs:
beam_outputs[seq_id] = outstanding_beams.pop()
assert not outstanding_beams
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert len(seq_ids) == 1
next_token_id = torch.argmax(probs, dim=-1)
next_token_ids = [next_token_id.item()]
parent_seq_ids = seq_ids
else:
# Nucleus sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(
probs, num_samples=1, replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
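
The survive-or-fork bookkeeping above can be traced with a tiny invented example: two live beams (seq ids 4 and 5) where both top-2 continuations come from beam 4, so beam 5 is handed the outstanding continuation and becomes a fork of beam 4.

    from typing import Dict, List, Tuple

    seq_ids = [4, 5]
    beam_seq_ids = [4, 4]   # both picked continuations extend beam 4
    token_ids = [11, 17]

    beam_outputs: Dict[int, Tuple[int, int]] = {}
    outstanding_beams: List[Tuple[int, int]] = []
    # If a beam survives, continue with it.
    for seq_id, token_id in zip(beam_seq_ids, token_ids):
        if seq_id not in beam_outputs:
            beam_outputs[seq_id] = (seq_id, token_id)
        else:
            outstanding_beams.append((seq_id, token_id))
    # If a beam is discarded, fork another beam.
    for seq_id in seq_ids:
        if seq_id not in beam_outputs:
            beam_outputs[seq_id] = outstanding_beams.pop()

    assert beam_outputs == {4: (4, 11), 5: (4, 17)}
    # Sequence 5 continues as a fork of sequence 4 (parent_seq_id == 4).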
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
seq_outputs: Dict[int, SequenceOutputs] = {}
# TODO(woosuk): Optimize.
idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.num_logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id, next_token_id, output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
logprob = logprobs[idx:idx + len(seq_ids)]
idx += len(seq_ids)
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.num_logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
next_token_id,
output_logprobs,
)
return seq_outputs

View File

@@ -5,27 +5,51 @@ class SamplingParams:
    def __init__(
        self,
-        n: int = 1,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        use_beam_search: bool = False,
-        stop_token_ids: Set[int] = [],
-        max_num_steps: int = 16, # From OpenAI API.
-        max_context_len: Optional[int] = None,
+        n: int,
+        temperature: float,
+        top_p: float,
+        use_beam_search: bool,
+        stop_token_ids: Set[int],
+        max_num_steps: int,
+        num_logprobs: int,
+        context_window_size: Optional[int],
    ) -> None:
-        assert n >= 1
-        assert temperature >= 0.0
-        assert 0.0 < top_p <= 1.0
+        if n < 1:
+            raise ValueError(f'n must be at least 1, got {n}.')
+        if temperature < 0.0:
+            raise ValueError(
+                f'temperature must be non-negative, got {temperature}.')
+        if not 0.0 < top_p <= 1.0:
+            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+        if max_num_steps < 1:
+            raise ValueError(
+                f'max_num_steps must be at least 1, got {max_num_steps}.')
+        if num_logprobs < 0:
+            raise ValueError(
+                f'num_logprobs must be non-negative, got {num_logprobs}.')
+        if context_window_size is not None and context_window_size < 0:
+            raise ValueError(
+                'context_window_size must be non-negative, '
+                f'got {context_window_size}.')
+
        if use_beam_search:
-            assert n > 1
-            assert temperature > 0.0
-            assert top_p == 1.0
+            if n == 1:
+                raise ValueError(
+                    'n must be greater than 1 when using beam search.')
+            if temperature > 0.0:
+                raise ValueError(
+                    'temperature must be 0 when using beam search.')
+            if top_p < 1.0:
+                raise ValueError(
+                    'top_p must be 1 when using beam search.')
        elif temperature == 0.0:
-            # Zero temperature means greedy decoding.
-            assert n == 1
-            assert top_p == 1.0
-        assert max_num_steps >= 1
-        assert max_context_len is None or max_context_len >= 0
+            # Zero temperature means greedy sampling.
+            if n > 1:
+                raise ValueError(
+                    'n must be 1 when using greedy sampling.')
+            if top_p < 1.0:
+                raise ValueError(
+                    'top_p must be 1 when using greedy sampling.')

        self.n = n
        self.temperature = temperature
@@ -33,4 +57,15 @@ class SamplingParams:
        self.use_beam_search = use_beam_search
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
-        self.max_context_len = max_context_len
+        self.num_logprobs = num_logprobs
+        self.context_window_size = context_window_size
+
+    def __repr__(self) -> str:
+        return (f'SamplingParams(n={self.n}, '
+                f'temperature={self.temperature}, '
+                f'top_p={self.top_p}, '
+                f'use_beam_search={self.use_beam_search}, '
+                f'stop_token_ids={self.stop_token_ids}, '
+                f'max_num_steps={self.max_num_steps}, '
+                f'num_logprobs={self.num_logprobs}, '
+                f'context_window_size={self.context_window_size})')
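
A hedged usage sketch of the stricter constructor, assuming the cacheflow package from this commit; the values are illustrative:

    from cacheflow.sampling_params import SamplingParams

    # Beam search: n > 1, temperature must be 0.0, top_p must be 1.0.
    beam = SamplingParams(
        n=4, temperature=0.0, top_p=1.0, use_beam_search=True,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None)

    # Nucleus sampling: positive temperature, top_p in (0, 1].
    nucleus = SamplingParams(
        n=3, temperature=0.8, top_p=0.99, use_beam_search=False,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=1,
        context_window_size=None)

    # This would raise: greedy sampling (temperature 0.0) only supports n=1.
    # SamplingParams(n=2, temperature=0.0, top_p=1.0, use_beam_search=False,
    #                stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
    #                context_window_size=None)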

View File

@@ -1,7 +1,9 @@
+import copy
import enum
-from typing import List, Optional
+from typing import Dict, List, Optional

from cacheflow.block import LogicalTokenBlock
+from cacheflow.sampling_params import SamplingParams


class SequenceStatus(enum.Enum):
@@ -24,9 +26,11 @@ class Sequence:
        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the given token ids.
-        self.append(token_ids)
+        self.add(token_ids)
        self.status = SequenceStatus.PENDING
+        self.output_logprobs: List[Dict[int, float]] = []
+        self.cumulative_logprobs = 1.0

    def add_block(self) -> None:
        block = LogicalTokenBlock(
@@ -35,7 +39,7 @@ class Sequence:
        )
        self.logical_token_blocks.append(block)

-    def append(self, token_ids: List[int]) -> None:
+    def add(self, token_ids: List[int]) -> None:
        while token_ids:
            if not self.logical_token_blocks:
                self.add_block()
@@ -49,6 +53,12 @@ class Sequence:
            last_block.append(token_ids[:num_empty_slots])
            token_ids = token_ids[num_empty_slots:]

+    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+        assert token_id in logprobs
+        self.add([token_id])
+        self.output_logprobs.append(logprobs)
+        self.cumulative_logprobs += logprobs[token_id]
+
    def get_len(self) -> int:
        return sum(block.num_tokens for block in self.logical_token_blocks)
@@ -58,6 +68,14 @@ class Sequence:
            token_ids.extend(block.get_token_ids())
        return token_ids

+    def get_last_token_id(self) -> int:
+        return self.logical_token_blocks[-1].get_last_token_id()
+
+    def fork(self, child_seq: 'Sequence') -> 'Sequence':
+        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
+        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
+        child_seq.cumulative_logprobs = self.cumulative_logprobs
+
    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
@@ -74,11 +92,17 @@ class SequenceGroup:
        self.group_id = group_id
        self.seqs = seqs

-    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+    def get_seqs(
+        self,
+        status: Optional[SequenceStatus] = None,
+    ) -> List[Sequence]:
        if status is None:
-            return len(self.seqs)
+            return self.seqs
        else:
-            return len([seq for seq in self.seqs if seq.status == status])
+            return [seq for seq in self.seqs if seq.status == status]
+
+    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+        return len(self.get_seqs(status))

    def find(self, seq_id: int) -> Sequence:
        for seq in self.seqs:
@@ -92,3 +116,45 @@ class SequenceGroup:
    def __repr__(self) -> str:
        return (f'SequenceGroup(group_id={self.group_id}, '
                f'num_seqs={len(self.seqs)})')
class SequenceGroupInputs:
def __init__(
self,
group_id: int,
is_prompt: bool,
input_tokens: Dict[int, List[int]], # Seq id -> token ids.
context_len: int,
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
) -> None:
self.group_id = group_id
self.is_prompt = is_prompt
self.input_tokens = input_tokens
self.context_len = context_len
self.seq_logprobs = seq_logprobs
self.sampling_params = sampling_params
self.block_tables = block_tables
class SequenceOutputs:
def __init__(
self,
seq_id: int,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
) -> None:
self.seq_id = seq_id
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
def __repr__(self) -> str:
return (f'SequenceOutputs(seq_id={self.seq_id}, '
f'parent_seq_id={self.parent_seq_id}, '
f'output_token={self.output_token}), '
f'logprobs={self.logprobs}')
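
To tie the new classes back to the scheduler changes above: when the sampler keeps a beam alive under another sequence's slot, it reports that through parent_seq_id, and Scheduler.post_step forks before appending. A minimal sketch, assuming the cacheflow package from this commit; the token id and logprob are invented:

    from cacheflow.sequence import SequenceOutputs

    # Beam search filled the slot of sequence 5 with a continuation of sequence 4.
    output = SequenceOutputs(
        seq_id=5,
        parent_seq_id=4,
        output_token=17,
        logprobs={17: -0.7},   # token id -> logP(x_i+1 | x_0, ..., x_i)
    )
    assert output.seq_id != output.parent_seq_id
    # post_step then runs parent_seq.fork(seq) and block_manager.fork(parent_seq, seq)
    # before seq.append(output.output_token, output.logprobs).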

View File

@@ -97,7 +97,7 @@ class CacheEngine:
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

-    def _copy_blocks(
+    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
@@ -108,19 +108,38 @@ class CacheEngine:
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
-                cache_ops.copy_cache_blocks(
+                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
-                cache_ops.copy_cache_blocks(
+                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

-    def copy(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)
-
    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
+        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
-        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
+        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
def _copy(
self,
src: List[KVCache],
dst: List[KVCache],
src_to_dsts: Dict[int, List[int]],
) -> None:
with torch.cuda.stream(self.cache_stream):
for i in range(self.num_layers):
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
cache_ops.copy_blocks(
src_key_cache, dst_key_cache, src_to_dsts)
# Copy the value blocks.
cache_ops.copy_blocks(
src_value_cache, dst_value_cache, src_to_dsts)
event = self.events[i]
event.record(stream=self.cache_stream)
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
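
A short sketch of how the scheduler's dict-of-lists mapping reaches the cache, assuming a constructed CacheEngine instance named cache_engine; the block numbers are invented:

    # Copy one source block into two destination blocks in a single step,
    # e.g. after two beams diverged from the same parent sequence.
    blocks_to_copy = {7: [12, 13]}
    cache_engine.copy(blocks_to_copy)   # issues cache_ops.copy_blocks for every layer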

View File

@@ -1,6 +1,7 @@
from typing import Dict, List, Union

from cacheflow.master.scheduler import Scheduler
+from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker

@@ -14,7 +15,8 @@ class Controller:
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
-        dtype: str = 'half',
+        dtype: str,
+        seed: int,
    ) -> None:
        self.node_id = node_id
        self.num_workers = num_workers
@@ -37,6 +39,7 @@ class Controller:
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                dtype=dtype,
+                seed=seed,
            )
            self.workers.append(worker)
@@ -49,22 +52,16 @@ class Controller:
    def execute_stage(
        self,
-        prompt_tokens: Dict[int, List[int]],
-        generation_tokens: Dict[int, int],
-        context_lens: Dict[int, int],
-        block_tables: Dict[int, List[int]],
+        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # FIXME: Support tensor parallelism.
        assert len(self.workers) == 1
        worker = self.workers[0]
        output = worker.execute_stage(
-            prompt_tokens,
-            generation_tokens,
-            context_lens,
-            block_tables,
+            input_seq_groups,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,

View File

@@ -1,9 +1,13 @@
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

import torch

from cacheflow.models import get_model
+from cacheflow.models import set_seed
from cacheflow.models import InputMetadata
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine

@@ -18,6 +22,7 @@ class Worker:
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
+        seed: int,
    ) -> None:
        self.worker_id = worker_id
        self.gpu_id = gpu_id
@@ -33,6 +38,11 @@ class Worker:
        self.head_size = self.model.config.hidden_size // self.num_heads
        self.dtype = self.model.dtype

+        # Set the seed.
+        # We set the seed after initializing the model to ensure that
+        # the random state is not affected by the model initialization.
+        set_seed(seed)
+
        self.cache_engine = CacheEngine(
            worker_id=worker_id,
            gpu_id=gpu_id,
@@ -49,55 +59,81 @@ class Worker:
    def prepare_inputs(
        self,
-        prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int], # Seq id -> Input token id.
-        context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
-        # TODO(woosuk): Support interactive generation.
-        # Add the prompt tokens.
-        prompt_lens: List[int] = []
+        seq_groups: List[Tuple[List[int], SamplingParams]] = []
+        seq_logprobs: Dict[int, float] = {}
+        sampling_params: Dict[int, SamplingParams] = {}
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

-        prompt_seq_ids = sorted(prompt_tokens.keys())
-        for seq_id in prompt_seq_ids:
-            prompt_len = len(prompt_tokens[seq_id])
+        # Add prompt tokens.
+        prompt_lens: List[int] = []
+        for input_seq_group in input_seq_groups:
+            if not input_seq_group.is_prompt:
+                continue
+
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            # Use any sequence in the group.
+            seq_id = seq_ids[0]
+
+            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)
-            input_tokens.extend(prompt_tokens[seq_id])
-            input_positions.extend(range(len(prompt_tokens[seq_id])))

-            block_table = block_tables[seq_id]
+            input_tokens.extend(prompt_tokens)
+            # NOTE(woosuk): Here we assume that the first token in the prompt
+            # is always the first token in the sequence.
+            input_positions.extend(range(len(prompt_tokens)))
+
+            # Compute the slot mapping.
+            block_table = input_seq_group.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

-        # Add the generation tokens.
+        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
+        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
+        for input_seq_group in input_seq_groups:
+            if input_seq_group.is_prompt:
+                continue

-        generation_seq_ids = sorted(generation_tokens.keys())
-        for seq_id in generation_seq_ids:
-            input_tokens.append(generation_tokens[seq_id])
-            position_id = context_lens[seq_id] - 1
-            input_positions.append(position_id)
-            block_table = block_tables[seq_id]
-            generation_block_tables.append(block_table)
-            max_context_len = max(max_context_len, context_lens[seq_id])
-            max_num_blocks_per_seq = max(
-                max_num_blocks_per_seq, len(block_table))
-            block_number = block_table[position_id // self.block_size]
-            block_offset = position_id % self.block_size
-            slot = block_number * self.block_size + block_offset
-            slot_mapping.append(slot)
+            seq_ids = list(input_seq_group.input_tokens.keys())
+            sampling_params = input_seq_group.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+            seq_logprobs.update(input_seq_group.seq_logprobs)
+
+            for seq_id in seq_ids:
+                assert len(input_seq_group.input_tokens[seq_id]) == 1
+                generation_token = input_seq_group.input_tokens[seq_id][0]
+                input_tokens.append(generation_token)
+
+                position = input_seq_group.context_len - 1
+                input_positions.append(position)
+
+                block_table = input_seq_group.block_tables[seq_id]
+                generation_block_tables.append(block_table)
+
+                max_context_len = max(
+                    max_context_len, input_seq_group.context_len)
+                max_num_blocks_per_seq = max(
+                    max_num_blocks_per_seq, len(block_table))
+                context_lens.append(input_seq_group.context_len)
+
+                block_number = block_table[position // self.block_size]
+                block_offset = position % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
@@ -112,8 +148,7 @@ class Worker:
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device=self.device)
        context_lens_tensor = torch.tensor(
-            [context_lens[seq_id] for seq_id in generation_seq_ids],
-            dtype=torch.int, device=self.device)
+            context_lens, dtype=torch.int, device=self.device)
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
@@ -121,7 +156,8 @@ class Worker:
            padded_block_tables, dtype=torch.int, device=self.device)

        input_metadata = InputMetadata(
-            seq_ids=prompt_seq_ids + generation_seq_ids,
+            seq_groups=seq_groups,
+            seq_logprobs=seq_logprobs,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
@@ -133,14 +169,11 @@ class Worker:
    @torch.inference_mode()
    def execute_stage(
        self,
-        prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids.
-        generation_tokens: Dict[int, int], # Seq id -> Input token id.
-        context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention.
-        block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
+        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, int],
-    ) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> Dict[int, SequenceOutputs]:
        # Issue cache operations.
        command_issued = False
        if blocks_to_swap_in:
@@ -160,7 +193,7 @@ class Worker:
        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            prompt_tokens, generation_tokens, context_lens, block_tables)
+            input_seq_groups)

        # Execute the model.
        output = self.model(
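
The slot mapping used in prepare_inputs is plain block arithmetic; a worked example with an invented block size of 8 and an invented block table:

    block_size = 8
    block_table = [3, 10, 42]     # logical block index -> physical block number
    position = 19                 # the sequence's 20th token
    block_number = block_table[position // block_size]   # block_table[2] == 42
    block_offset = position % block_size                 # 3
    slot = block_number * block_size + block_offset      # 42 * 8 + 3
    assert slot == 339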

View File

@@ -1,9 +1,17 @@
#include <torch/extension.h>

+#include <map>
+#include <vector>
+
+void swap_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping);
+
void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping);
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping);

void reshape_and_cache(
  torch::Tensor& key,
@@ -14,7 +22,11 @@ void reshape_and_cache(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
-    "copy_cache_blocks",
+    "swap_blocks",
+    &swap_blocks,
+    "Swap in (out) the cache blocks from src to dst");
+  m.def(
+    "copy_blocks",
    &copy_blocks,
    "Copy the cache blocks from src to dst");
  m.def(

View File

@@ -5,8 +5,9 @@
#include <algorithm>
#include <cassert>
#include <map>
+#include <vector>

-void copy_blocks(
+void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
@@ -43,6 +44,35 @@
  }
}
void copy_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
assert(src_device.is_cuda() && dst_device.is_cuda());
cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
void *src_ptr = src.data_ptr();
void *dst_ptr = dst.data_ptr();
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(
dst_ptr + dst_offset,
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
}
}
}
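
The offset arithmetic above treats each cache block as one contiguous chunk of bytes; a worked example in Python with invented tensor sizes:

    # Each block holds 16 tokens * 12 heads * 64 head_size fp16 values.
    element_size = 2                                       # bytes per fp16 value
    values_per_block = 16 * 12 * 64                        # src[0].numel() for this shape
    block_size_in_bytes = element_size * values_per_block  # 24576 bytes
    src_block_number, dst_block_number = 7, 12
    src_offset = src_block_number * block_size_in_bytes    # 172032
    dst_offset = dst_block_number * block_size_in_bytes    # 294912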
template<typename scalar_t> template<typename scalar_t>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]

View File

@@ -15,6 +15,8 @@ parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of
parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
+# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()

@@ -30,6 +32,7 @@ def main():
            num_gpu_blocks=args.num_gpu_blocks,
            num_cpu_blocks=args.num_cpu_blocks,
            dtype=args.dtype,
+            seed=args.seed,
        )
        controllers.append(controller)

@@ -52,18 +55,18 @@ def main():
        controllers[i].set_next(controllers[i + 1])
    controllers[-1].set_next(scheduler)

+    # Test the following inputs.
    test_inputs = [
-        'Ion Stoica is a',
-        'UC Berkeley is',
-        'The future of cloud computing is',
+        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
+        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
+        ('The future of cloud computing is', {}),  # Use default parameters.
    ]

-    # FIXME
    while True:
        if test_inputs:
-            frontend.query(test_inputs.pop())
+            text, sampling_params = test_inputs.pop(0)
+            frontend.query(text, **sampling_params)
        scheduler.step()
-        if not scheduler.pending and not scheduler.running:
+        if not (scheduler.pending or scheduler.running or test_inputs):
            break