Support beam search & parallel generation (#7)
parent 04e5acc08e
commit 1a7eb7da61
@@ -35,6 +35,10 @@ class LogicalTokenBlock:
    def get_token_ids(self) -> List[int]:
        return self.token_ids[:self.num_tokens]

    def get_last_token_id(self) -> int:
        assert self.num_tokens > 0
        return self.token_ids[self.num_tokens - 1]


class PhysicalTokenBlock:
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

from transformers import AutoTokenizer

@@ -25,12 +25,35 @@ class Frontend:
    def query(
        self,
        prompt: str,
        sampling_params: Optional[SamplingParams] = None,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 1.0,
        use_beam_search: bool = False,
        stop_token_ids: Set[int] = set(),
        max_num_steps: int = 16, # From OpenAI API.
        num_logprobs: int = 0,
        context_window_size: Optional[int] = None,
    ) -> None:
        if sampling_params is None:
            sampling_params = SamplingParams()
        token_ids: List[int] = self.tokenizer.encode(prompt)
        # Stop when we see an EOS token.
        stop_token_ids.add(self.tokenizer.eos_token_id)
        sampling_params = SamplingParams(
            n=n,
            temperature=temperature,
            top_p=top_p,
            use_beam_search=use_beam_search,
            stop_token_ids=stop_token_ids,
            max_num_steps=max_num_steps,
            num_logprobs=num_logprobs,
            context_window_size=context_window_size,
        )
        token_ids = self.tokenizer.encode(prompt)
        self._add_query(token_ids, sampling_params)

    def _add_query(
        self,
        token_ids: List[int],
        sampling_params: SamplingParams,
    ) -> None:
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
@@ -1,10 +1,12 @@
from typing import Dict, List, Tuple
from typing import Dict, List

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus

_MAX_NUM_BATCHED_TOKENS = 2048
@@ -66,7 +68,7 @@ class Scheduler:
    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
@@ -74,7 +76,10 @@ class Scheduler:
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
                blocks_to_copy[src_block] = dst_block
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _swap_in(
        self,
@@ -83,9 +88,8 @@ class Scheduler:
    ) -> None:
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.SWAPPED:
                seq.status = SequenceStatus.RUNNING
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)

    def _swap_out(
@@ -96,16 +100,15 @@ class Scheduler:
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.RUNNING:
                seq.status = SequenceStatus.SWAPPED
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
        self.swapped.append(seq_group)

    def step(self) -> None:
        # Blocks that need to be swaped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # 1. Reserve new slots for the running sequences.
        # NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@ class Scheduler:
            # All swapped sequences are swapped in.
            self.swapped.clear()

        # Ensure that swap-in and swap-out never happen at the same timestep.
        if blocks_to_swap_in:
            assert not blocks_to_swap_out

        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
@@ -152,7 +159,6 @@ class Scheduler:
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        if not self.swapped:
            # FIXME(woosuk): Acquire a lock to protect pending.
            self._fetch_inputs()
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,39 +174,45 @@ class Scheduler:
            else:
                self.pending.clear()

        # Ensure that swap-in and swap-out never happen at the same timestep.
        if blocks_to_swap_in:
            assert not blocks_to_swap_out

        # 4. Create input data structures.
        prompt_tokens: Dict[int, List[int]] = {}
        generation_tokens: Dict[int, int] = {}
        context_lens: Dict[int, int] = {}
        block_tables: Dict[int, List[int]] = {}
        input_seq_groups: List[SequenceGroupInputs] = []
        for seq_group in self.running:
            group_id = seq_group.group_id
            num_steps = self.num_steps[group_id]

            # NOTE(woosuk): We assume that the number of steps is 0
            # for the prompt sequences.
            is_prompt = num_steps == 0
            for seq in seq_group.seqs:
                if seq.status != SequenceStatus.RUNNING:
                    continue

            input_tokens: Dict[int, List[int]] = {}
            seq_logprobs: Dict[int, float] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    prompt_tokens[seq_id] = seq.get_token_ids()
                    input_tokens[seq_id] = seq.get_token_ids()
                else:
                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
                    context_lens[seq_id] = seq.get_len()
                    input_tokens[seq_id] = [seq.get_last_token_id()]
                seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length
                seq_len = seq.get_len()

            input_seq_group = SequenceGroupInputs(
                group_id=group_id,
                is_prompt=is_prompt,
                input_tokens=input_tokens,
                context_len=seq_len,
                seq_logprobs=seq_logprobs,
                sampling_params=self.sampling_params[group_id],
                block_tables=block_tables,
            )
            input_seq_groups.append(input_seq_group)

        # 5. Execute the first stage of the pipeline.
        self.controllers[0].execute_stage(
            prompt_tokens,
            generation_tokens,
            context_lens,
            block_tables,
            input_seq_groups,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
@@ -208,7 +220,7 @@ class Scheduler:

    def post_step(
        self,
        next_tokens: Dict[int, Tuple[int, int]],
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
        # Update the running sequences and free blocks.
        for seq_group in self.running:
@@ -216,25 +228,32 @@ class Scheduler:
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

            # Process beam search results before processing the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                parent_seq_id, next_token = next_tokens[seq.seq_id]
                if seq.seq_id != parent_seq_id:
                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(parent_seq_id)
                    seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                # Append a new token to the sequence.
                seq.append([next_token])
                output = seq_outputs[seq.seq_id]
                seq.append(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
                if next_token in stop_token_ids:
                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue
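Note (editor's sketch, not part of the commit): with parallel generation, several forked sequences can trigger copy-on-write on the same source block within one step, which is why blocks_to_copy becomes a Dict[int, List[int]]. A minimal standalone illustration of the same accumulation pattern, using made-up block numbers:

    from typing import Dict, List

    blocks_to_copy: Dict[int, List[int]] = {}

    def record_copy(src_block: int, dst_block: int) -> None:
        # Same accumulation as in Scheduler._append above.
        if src_block in blocks_to_copy:
            blocks_to_copy[src_block].append(dst_block)
        else:
            blocks_to_copy[src_block] = [dst_block]

    record_copy(7, 12)  # first fork copies source block 7 into block 12
    record_copy(7, 13)  # a second fork copies the same source block into block 13
    assert blocks_to_copy == {7: [12, 13]}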
@@ -1,8 +1,10 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_model
from cacheflow.models.model_utils import set_seed


__all__ = [
    'get_model',
    'InputMetadata',
    'get_model',
    'set_seed'
]
@@ -1,21 +1,24 @@
from typing import List
from typing import List, Dict, Tuple

import torch

from cacheflow.sampling_params import SamplingParams


class InputMetadata:

    def __init__(
        self,
        seq_ids: List[int],
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        # FIXME: Rename
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
        self.seq_ids = seq_ids
        self.seq_groups = seq_groups
        self.seq_logprobs = seq_logprobs
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
@@ -23,19 +26,20 @@ class InputMetadata:
        self.block_tables = block_tables

        self.num_prompts = len(prompt_lens)
        self.num_prompt_tokens = sum(prompt_lens)
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = slot_mapping.shape[0]
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
        assert self.num_generation_tokens == block_tables.shape[0]
        assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
        assert block_tables.shape[0] == self.num_generation_tokens
        assert context_lens.shape[0] == self.num_generation_tokens

    def __repr__(self) -> str:
        return (f'InputMetadata('
                f'seq_ids={self.seq_ids}, '
                f'num_prompts={self.num_prompts}, '
                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
@@ -1,5 +1,7 @@
import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn

@@ -30,3 +32,11 @@ def get_model(
        model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
        return model.eval()
    raise ValueError(f'Invalid model name: {model_name}')


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
@@ -9,6 +9,7 @@ from transformers import PreTrainedModel
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -261,7 +262,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, Tuple[int, int]]:
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
@@ -4,6 +4,8 @@ import torch
import torch.nn as nn

from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs


class Sampler(nn.Module):
@@ -16,27 +18,266 @@ class Sampler(nn.Module):
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> Dict[int, Tuple[int, int]]:
        # Get the hidden states of the last tokens.
        start_idx = 0
        last_token_indicies: List[int] = []
        for prompt_len in input_metadata.prompt_lens:
            last_token_indicies.append(start_idx + prompt_len - 1)
            start_idx += prompt_len
        last_token_indicies.extend(
            range(start_idx, start_idx + input_metadata.num_generation_tokens))
        hidden_states = hidden_states[last_token_indicies]
    ) -> Dict[int, SequenceOutputs]:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())

        # Sample the next tokens.
        # TODO(woosuk): Implement other sampling methods.
        next_token_ids = torch.argmax(logits, dim=-1)
        next_token_ids = next_token_ids.tolist()
        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(
                temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Return the next tokens.
        next_tokens: Dict[int, Tuple[int, int]] = {}
        for seq_id, token_id in zip(input_metadata.seq_ids, next_token_ids):
            next_tokens[seq_id] = (seq_id, token_id)
        return next_tokens
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities (before applying top-p).
        logprobs = torch.log(probs)

        # Apply top-p truncation.
        top_ps = _get_top_ps(input_metadata)
        assert len(top_ps) == probs.shape[0]
        if any(p < 1.0 for p in top_ps):
            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
            probs = _apply_top_p(probs, p)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    start_idx = 0
    last_token_indicies: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indicies.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indicies.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states[last_token_indicies]


def _get_temperatures(
    input_metadata: InputMetadata,
) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature == 0.0:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0

        if i < input_metadata.num_prompts:
            # A prompt input.
            temperatures.append(temperature)
        else:
            # A generation token.
            temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_ps(
    input_metadata: InputMetadata,
) -> List[float]:
    top_ps: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # A prompt input.
            top_ps.append(sampling_params.top_p)
        else:
            # A generation token.
            top_ps += [sampling_params.top_p] * len(seq_ids)
    return top_ps


def _apply_top_p(
    probs: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    # TODO(woosuk): Optimize.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    probs = torch.gather(
        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
    return probs


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: int,
) -> Dict[int, float]:
    if num_logprobs == 0:
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
    if sampling_params.use_beam_search:
        # Beam search.
        beam_width = sampling_params.n
        _, next_token_ids = torch.topk(prob, beam_width)
        next_token_ids = next_token_ids.tolist()
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert sampling_params.n == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
    else:
        # Neucleus sampling.
        # Sample n tokens for the prompt.
        n = sampling_params.n
        next_token_ids = torch.multinomial(
            prob, num_samples=n, replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    # NOTE(woosuk): sampling_params.n can be greater than
    # len(seq_ids) because some sequences in the group might have
    # been already terminated.
    if sampling_params.use_beam_search:
        # Beam search.
        # Add cumulative logprobs for the sequences in the group.
        seq_logprobs = torch.tensor(
            seq_logprobs, dtype=torch.float, device=logprobs.device)
        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

        vocab_size = logprobs.size(-1)
        beam_width = len(seq_ids)
        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
        beam_seq_ids = [seq_ids[i] for i in seq_idx]
        token_ids = (topk_ids % vocab_size).tolist()

        beam_outputs: Dict[int, Tuple[int, int]] = {}
        outstanding_beams: List[Tuple[int, int]] = []
        # If a beam survives, continue with it.
        for seq_id, token_id in zip(beam_seq_ids, token_ids):
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = (seq_id, token_id)
            else:
                outstanding_beams.append((seq_id, token_id))

        # If a beam is discarded, fork another beam.
        for seq_id in seq_ids:
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = outstanding_beams.pop()
        assert not outstanding_beams

        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert len(seq_ids) == 1
        next_token_id = torch.argmax(probs, dim=-1)
        next_token_ids = [next_token_id.item()]
        parent_seq_ids = seq_ids
    else:
        # Neucleus sampling.
        # Sample 1 token for each sequence in the group.
        next_token_ids = torch.multinomial(
            probs, num_samples=1, replacement=True)
        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
        parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
    seq_outputs: Dict[int, SequenceOutputs] = {}

    # TODO(woosuk): Optimize.
    idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == sampling_params.n
            prob = probs[idx]
            logprob = logprobs[idx]
            idx += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)
            # Get top-k log probabilities for the next tokens.
            next_logprobs = _get_topk_logprobs(
                logprob, sampling_params.num_logprobs)

            # Build the output.
            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                output_logprobs = next_logprobs.copy()
                output_logprobs[next_token_id] = logprob[next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id, seq_id, next_token_id, output_logprobs)
        else:
            # Generate the next tokens for generation tokens.
            prob = probs[idx:idx + len(seq_ids)]
            logprob = logprobs[idx:idx + len(seq_ids)]
            idx += len(seq_ids)

            # Sample the next tokens.
            seq_logprobs = [
                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                seq_ids, prob, logprob, seq_logprobs, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs: Dict[int, Dict[int, float]] = {}
            for i, seq_id in enumerate(seq_ids):
                next_logprobs[seq_id] = _get_topk_logprobs(
                    logprob[i], sampling_params.num_logprobs)

            # Build the output.
            for seq_id, parent_seq_id, next_token_id in zip(
                seq_ids, parent_seq_ids, next_token_ids):
                i = seq_ids.index(parent_seq_id)
                output_logprobs = next_logprobs[parent_seq_id].copy()
                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id,
                    parent_seq_id,
                    next_token_id,
                    output_logprobs,
                )

    return seq_outputs
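Note (editor's sketch, not part of the commit): a tiny standalone example of the top-p truncation that _apply_top_p implements above, using made-up probabilities. A token is dropped once the cumulative mass before it already exceeds p, and the survivors are renormalized.

    import torch

    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    p = torch.tensor([0.7])

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    truncated = torch.gather(
        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
    # truncated == [[0.625, 0.375, 0.0, 0.0]]: only the two most likely
    # tokens stay inside the 0.7 nucleus.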
@@ -5,27 +5,51 @@ class SamplingParams:

    def __init__(
        self,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 1.0,
        use_beam_search: bool = False,
        stop_token_ids: Set[int] = [],
        max_num_steps: int = 16, # From OpenAI API.
        max_context_len: Optional[int] = None,
        n: int,
        temperature: float,
        top_p: float,
        use_beam_search: bool,
        stop_token_ids: Set[int],
        max_num_steps: int,
        num_logprobs: int,
        context_window_size: Optional[int],
    ) -> None:
        assert n >= 1
        assert temperature >= 0.0
        assert 0.0 < top_p <= 1.0
        if n < 1:
            raise ValueError(f'n must be at least 1, got {n}.')
        if temperature < 0.0:
            raise ValueError(
                f'temperature must be non-negative, got {temperature}.')
        if not 0.0 < top_p <= 1.0:
            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
        if max_num_steps < 1:
            raise ValueError(
                f'max_num_steps must be at least 1, got {max_num_steps}.')
        if num_logprobs < 0:
            raise ValueError(
                f'num_logprobs must be non-negative, got {num_logprobs}.')
        if context_window_size is not None and context_window_size < 0:
            raise ValueError(
                'context_window_size must be non-negative, '
                f'got {context_window_size}.')

        if use_beam_search:
            assert n > 1
            assert temperature > 0.0
            assert top_p == 1.0
            if n == 1:
                raise ValueError(
                    'n must be greater than 1 when using beam search.')
            if temperature > 0.0:
                raise ValueError(
                    'temperature must be 0 when using beam search.')
            if top_p < 1.0:
                raise ValueError(
                    'top_p must be 1 when using beam search.')
        elif temperature == 0.0:
            # Zero temperature means greedy decoding.
            assert n == 1
            assert top_p == 1.0
        assert max_num_steps >= 1
        assert max_context_len is None or max_context_len >= 0
            # Zero temperature means greedy sampling.
            if n > 1:
                raise ValueError(
                    'n must be 1 when using greedy sampling.')
            if top_p < 1.0:
                raise ValueError(
                    'top_p must be 1 when using greedy sampling.')

        self.n = n
        self.temperature = temperature
@@ -33,4 +57,15 @@ class SamplingParams:
        self.use_beam_search = use_beam_search
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
        self.max_context_len = max_context_len
        self.num_logprobs = num_logprobs
        self.context_window_size = context_window_size

    def __repr__(self) -> str:
        return (f'SamplingParams(n={self.n}, '
                f'temperature={self.temperature}, '
                f'top_p={self.top_p}, '
                f'use_beam_search={self.use_beam_search}, '
                f'stop_token_ids={self.stop_token_ids}, '
                f'max_num_steps={self.max_num_steps}, '
                f'num_logprobs={self.num_logprobs}, '
                f'context_window_size={self.context_window_size})')
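Note (editor's sketch, not part of the commit): two constructions that satisfy the checks above, passing every argument by keyword; the values are illustrative. Beam search requires n > 1, temperature == 0, and top_p == 1, while nucleus sampling only needs temperature > 0 and 0 < top_p <= 1.

    from cacheflow.sampling_params import SamplingParams

    beam = SamplingParams(
        n=4, temperature=0.0, top_p=1.0, use_beam_search=True,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None)
    nucleus = SamplingParams(
        n=3, temperature=0.8, top_p=0.99, use_beam_search=False,
        stop_token_ids=set(), max_num_steps=16, num_logprobs=0,
        context_window_size=None)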
@@ -1,7 +1,9 @@
import copy
import enum
from typing import List, Optional
from typing import Dict, List, Optional

from cacheflow.block import LogicalTokenBlock
from cacheflow.sampling_params import SamplingParams


class SequenceStatus(enum.Enum):
@@ -24,9 +26,11 @@ class Sequence:

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the given token ids.
        self.append(token_ids)
        self.add(token_ids)

        self.status = SequenceStatus.PENDING
        self.output_logprobs: List[Dict[int, float]] = []
        self.cumulative_logprobs = 1.0

    def add_block(self) -> None:
        block = LogicalTokenBlock(
@@ -35,7 +39,7 @@ class Sequence:
        )
        self.logical_token_blocks.append(block)

    def append(self, token_ids: List[int]) -> None:
    def add(self, token_ids: List[int]) -> None:
        while token_ids:
            if not self.logical_token_blocks:
                self.add_block()
@@ -49,6 +53,12 @@ class Sequence:
            last_block.append(token_ids[:num_empty_slots])
            token_ids = token_ids[num_empty_slots:]

    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
        assert token_id in logprobs
        self.add([token_id])
        self.output_logprobs.append(logprobs)
        self.cumulative_logprobs += logprobs[token_id]

    def get_len(self) -> int:
        return sum(block.num_tokens for block in self.logical_token_blocks)

@@ -58,6 +68,14 @@ class Sequence:
            token_ids.extend(block.get_token_ids())
        return token_ids

    def get_last_token_id(self) -> int:
        return self.logical_token_blocks[-1].get_last_token_id()

    def fork(self, child_seq: 'Sequence') -> 'Sequence':
        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        child_seq.cumulative_logprobs = self.cumulative_logprobs

    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
@@ -74,11 +92,17 @@ class SequenceGroup:
        self.group_id = group_id
        self.seqs = seqs

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        if status is None:
            return len(self.seqs)
            return self.seqs
        else:
            return len([seq for seq in self.seqs if seq.status == status])
            return [seq for seq in self.seqs if seq.status == status]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def find(self, seq_id: int) -> Sequence:
        for seq in self.seqs:
@@ -92,3 +116,45 @@ class SequenceGroup:
    def __repr__(self) -> str:
        return (f'SequenceGroup(group_id={self.group_id}, '
                f'num_seqs={len(self.seqs)})')


class SequenceGroupInputs:

    def __init__(
        self,
        group_id: int,
        is_prompt: bool,
        input_tokens: Dict[int, List[int]], # Seq id -> token ids.
        context_len: int,
        seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
    ) -> None:
        self.group_id = group_id
        self.is_prompt = is_prompt
        self.input_tokens = input_tokens
        self.context_len = context_len
        self.seq_logprobs = seq_logprobs
        self.sampling_params = sampling_params
        self.block_tables = block_tables


class SequenceOutputs:

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}), '
                f'logprobs={self.logprobs}')
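Note (editor's sketch, not part of the commit): how Scheduler.post_step is meant to read SequenceOutputs, with made-up sequence ids, token ids, and logprob values. When parent_seq_id differs from seq_id, the sequence is a beam-search fork: it inherits the parent's blocks and logprobs via fork() before the new token is appended.

    from cacheflow.sequence import SequenceOutputs

    outputs = {
        0: SequenceOutputs(seq_id=0, parent_seq_id=0, output_token=11,
                           logprobs={11: -0.7}),
        1: SequenceOutputs(seq_id=1, parent_seq_id=0, output_token=42,
                           logprobs={42: -1.4}),
    }
    # seq 0 simply continues itself; seq 1 is re-seeded from seq 0:
    #   parent_seq.fork(seq_1); block_manager.fork(parent_seq, seq_1)
    # and then each sequence appends its output_token via Sequence.append().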
@@ -97,7 +97,7 @@ class CacheEngine:
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _copy_blocks(
    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
@@ -108,19 +108,38 @@ class CacheEngine:
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.copy_cache_blocks(
                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.copy_cache_blocks(
                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def _copy(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dsts: Dict[int, List[int]],
    ) -> None:
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.copy_blocks(
                    src_key_cache, dst_key_cache, src_to_dsts)
                # Copy the value blocks.
                cache_ops.copy_blocks(
                    src_value_cache, dst_value_cache, src_to_dsts)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
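Note (editor's sketch, not part of the commit): how the scheduler-side mappings line up with the renamed engine methods, using made-up block numbers and assuming a constructed CacheEngine named cache_engine. Swaps stay one-to-one, while a copy can now fan one source block out to several destinations.

    blocks_to_swap_in = {3: 8}        # CPU block 3 -> GPU block 8
    blocks_to_copy = {8: [9, 10]}     # GPU block 8 -> GPU blocks 9 and 10
    # cache_engine.swap_in(blocks_to_swap_in)   # calls cache_ops.swap_blocks per layer
    # cache_engine.copy(blocks_to_copy)         # calls cache_ops.copy_blocks per layer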
@@ -1,6 +1,7 @@
from typing import Dict, List, Union

from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker


@@ -14,7 +15,8 @@ class Controller:
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str = 'half',
        dtype: str,
        seed: int,
    ) -> None:
        self.node_id = node_id
        self.num_workers = num_workers
@@ -37,6 +39,7 @@ class Controller:
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                dtype=dtype,
                seed=seed,
            )
            self.workers.append(worker)

@@ -49,22 +52,16 @@ class Controller:

    def execute_stage(
        self,
        prompt_tokens: Dict[int, List[int]],
        generation_tokens: Dict[int, int],
        context_lens: Dict[int, int],
        block_tables: Dict[int, List[int]],
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # FIXME: Support tensor parallelism.
        assert len(self.workers) == 1
        worker = self.workers[0]
        output = worker.execute_stage(
            prompt_tokens,
            generation_tokens,
            context_lens,
            block_tables,
            input_seq_groups,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
@@ -1,9 +1,13 @@
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple

import torch

from cacheflow.models import get_model
from cacheflow.models import set_seed
from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine


@@ -18,6 +22,7 @@ class Worker:
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
    ) -> None:
        self.worker_id = worker_id
        self.gpu_id = gpu_id
@@ -33,6 +38,11 @@ class Worker:
        self.head_size = self.model.config.hidden_size // self.num_heads
        self.dtype = self.model.dtype

        # Set the seed.
        # We set the seed after initializing the model to ensure that
        # the random state is not affected by the model initialization.
        set_seed(seed)

        self.cache_engine = CacheEngine(
            worker_id=worker_id,
            gpu_id=gpu_id,
@@ -49,55 +59,81 @@ class Worker:

    def prepare_inputs(
        self,
        prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids.
        generation_tokens: Dict[int, int], # Seq id -> Input token id.
        context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention.
        block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
        input_seq_groups: List[SequenceGroupInputs],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
        # TODO(woosuk): Support interactive generation.
        # Add the prompt tokens.
        prompt_lens: List[int] = []
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        seq_logprobs: Dict[int, float] = {}
        sampling_params: Dict[int, SamplingParams] = {}
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        prompt_seq_ids = sorted(prompt_tokens.keys())
        for seq_id in prompt_seq_ids:
            prompt_len = len(prompt_tokens[seq_id])
        # Add prompt tokens.
        prompt_lens: List[int] = []
        for input_seq_group in input_seq_groups:
            if not input_seq_group.is_prompt:
                continue

            seq_ids = list(input_seq_group.input_tokens.keys())
            sampling_params = input_seq_group.sampling_params
            seq_groups.append((seq_ids, sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            prompt_tokens = input_seq_group.input_tokens[seq_id]
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens[seq_id])
            input_positions.extend(range(len(prompt_tokens[seq_id])))
            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            block_table = block_tables[seq_id]
            # Compute the slot mapping.
            block_table = input_seq_group.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add the generation tokens.
        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for input_seq_group in input_seq_groups:
            if input_seq_group.is_prompt:
                continue

        generation_seq_ids = sorted(generation_tokens.keys())
        for seq_id in generation_seq_ids:
            input_tokens.append(generation_tokens[seq_id])
            position_id = context_lens[seq_id] - 1
            input_positions.append(position_id)
            seq_ids = list(input_seq_group.input_tokens.keys())
            sampling_params = input_seq_group.sampling_params
            seq_groups.append((seq_ids, sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            block_table = block_tables[seq_id]
            generation_block_tables.append(block_table)
            for seq_id in seq_ids:
                assert len(input_seq_group.input_tokens[seq_id]) == 1
                generation_token = input_seq_group.input_tokens[seq_id][0]
                input_tokens.append(generation_token)

            max_context_len = max(max_context_len, context_lens[seq_id])
            max_num_blocks_per_seq = max(
                max_num_blocks_per_seq, len(block_table))
                position = input_seq_group.context_len - 1
                input_positions.append(position)

            block_number = block_table[position_id // self.block_size]
            block_offset = position_id % self.block_size
            slot = block_number * self.block_size + block_offset
            slot_mapping.append(slot)
                block_table = input_seq_group.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(
                    max_context_len, input_seq_group.context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(input_seq_group.context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
@@ -112,8 +148,7 @@ class Worker:
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device=self.device)
        context_lens_tensor = torch.tensor(
            [context_lens[seq_id] for seq_id in generation_seq_ids],
            dtype=torch.int, device=self.device)
            context_lens, dtype=torch.int, device=self.device)
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
@@ -121,7 +156,8 @@ class Worker:
            padded_block_tables, dtype=torch.int, device=self.device)

        input_metadata = InputMetadata(
            seq_ids=prompt_seq_ids + generation_seq_ids,
            seq_groups=seq_groups,
            seq_logprobs=seq_logprobs,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
@@ -133,14 +169,11 @@ class Worker:
    @torch.inference_mode()
    def execute_stage(
        self,
        prompt_tokens: Dict[int, List[int]], # Seq id -> List of input token ids.
        generation_tokens: Dict[int, int], # Seq id -> Input token id.
        context_lens: Dict[int, int], # Seq id -> Number of tokens participating in attention.
        block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, int],
    ) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
        # Issue cache operations.
        command_issued = False
        if blocks_to_swap_in:
@@ -160,7 +193,7 @@ class Worker:

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            prompt_tokens, generation_tokens, context_lens, block_tables)
            input_seq_groups)

        # Execute the model.
        output = self.model(
@@ -1,9 +1,17 @@
#include <torch/extension.h>

#include <map>
#include <vector>

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);

void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);
  const std::map<int64_t, std::vector<int64_t>>& block_mapping);

void reshape_and_cache(
  torch::Tensor& key,
@@ -14,7 +22,11 @@ void reshape_and_cache(

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "copy_cache_blocks",
    "swap_blocks",
    &swap_blocks,
    "Swap in (out) the cache blocks from src to dst");
  m.def(
    "copy_blocks",
    &copy_blocks,
    "Copy the cache blocks from src to dst");
  m.def(
@@ -5,8 +5,9 @@
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

void copy_blocks(
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
@@ -43,6 +44,35 @@ void copy_blocks(
  }
}

void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  assert(src_device.is_cuda() && dst_device.is_cuda());
  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;

  void *src_ptr = src.data_ptr();
  void *dst_ptr = dst.data_ptr();

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      int64_t src_offset = src_block_number * block_size_in_bytes;
      int64_t dst_offset = dst_block_number * block_size_in_bytes;
      cudaMemcpyAsync(
        dst_ptr + dst_offset,
        src_ptr + src_offset,
        block_size_in_bytes,
        memcpy_type,
        stream);
    }
  }
}

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
server.py
@@ -15,6 +15,8 @@ parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of
parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()


@@ -30,6 +32,7 @@ def main():
            num_gpu_blocks=args.num_gpu_blocks,
            num_cpu_blocks=args.num_cpu_blocks,
            dtype=args.dtype,
            seed=args.seed,
        )
        controllers.append(controller)

@@ -52,18 +55,18 @@ def main():
        controllers[i].set_next(controllers[i + 1])
    controllers[-1].set_next(scheduler)

    # Test the following inputs.
    test_inputs = [
        'Ion Stoica is a',
        'UC Berkeley is',
        'The future of cloud computing is',
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}), # Use default parameters.
    ]

    # FIXME
    while True:
        if test_inputs:
            frontend.query(test_inputs.pop())
            text, sampling_params = test_inputs.pop(0)
            frontend.query(text, **sampling_params)
        scheduler.step()
        if not scheduler.pending and not scheduler.running:
        if not (scheduler.pending or scheduler.running or test_inputs):
            break