Rename variables and methods (#91)
commit 8d66a7b6d7 (parent ce26e57fd3)
@@ -27,7 +27,7 @@ class LogicalTokenBlock:
     def is_full(self) -> bool:
         return self.num_tokens == self.block_size

-    def append(self, token_ids: List[int]) -> None:
+    def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
         self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)
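For context on the rename above, here is a minimal, self-contained sketch of a LogicalTokenBlock with the new append_tokens name. Only is_full, get_num_empty_slots and append_tokens follow the diff; the constructor and the placeholder fill value are assumptions.

from typing import List

class LogicalTokenBlock:
    # Sketch of the block shown in the diff: a fixed-size slab of token slots.
    def __init__(self, block_number: int, block_size: int) -> None:
        self.block_number = block_number
        self.block_size = block_size
        self.token_ids: List[int] = [-1] * block_size  # placeholder value is an assumption
        self.num_tokens = 0

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size

    def get_num_empty_slots(self) -> int:
        return self.block_size - self.num_tokens

    def append_tokens(self, token_ids: List[int]) -> None:
        # Copy the new tokens into the next free slots, as in the diff.
        assert len(token_ids) <= self.get_num_empty_slots()
        self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)

block = LogicalTokenBlock(block_number=0, block_size=8)
block.append_tokens([1, 2, 3])
assert not block.is_full() and block.get_num_empty_slots() == 5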
@@ -97,15 +97,15 @@ class BlockSpaceManager:
         for seq in seq_group.seqs:
             self.block_tables[seq.seq_id] = block_table.copy()

-    def can_append(self, seq_group: SequenceGroup) -> bool:
+    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
         # Simple heuristic: If there is at least one free block
         # for each sequence, we can append.
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
         return num_seqs <= num_free_gpu_blocks

-    def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
-        """Allocate a physical slot for the new token."""
+    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
+        """Allocate a physical slot for a new token."""
         logical_blocks = seq.logical_token_blocks
         block_table = self.block_tables[seq.seq_id]

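The renamed can_append_slot keeps the simple heuristic spelled out in the comment: one free GPU block per running sequence. A standalone sketch of that check, with the allocator and the sequence group reduced to plain integers:

def can_append_slot(num_free_gpu_blocks: int, num_running_seqs: int) -> bool:
    # Heuristic from the diff: appending a slot is safe if every running
    # sequence in the group could claim at least one free GPU block.
    return num_running_seqs <= num_free_gpu_blocks

assert can_append_slot(num_free_gpu_blocks=4, num_running_seqs=2)
assert not can_append_slot(num_free_gpu_blocks=1, num_running_seqs=3)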
@@ -156,7 +156,7 @@ class BlockSpaceManager:
         num_free_blocks = self.gpu_allocator.get_num_free_blocks()
         # NOTE: Conservatively, we assume that every sequence will allocate
         # at least one free block right after the swap-in.
-        # NOTE: This should match the logic in can_append().
+        # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
         return num_free_blocks - num_required_blocks >= self.watermark_blocks

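The swap-in check in this hunk follows the same idea, conservatively reserving one extra block per swapped sequence on top of the blocks being swapped in, and keeping a watermark of free blocks. A sketch of the arithmetic, with the function boundary and argument names assumed:

def can_swap_in(num_free_blocks: int, num_blocks_to_swap_in: int,
                num_swapped_seqs: int, watermark_blocks: int) -> bool:
    # Conservative check from the diff: every swapped-in sequence is assumed
    # to allocate one more block immediately after the swap-in.
    num_required_blocks = num_blocks_to_swap_in + num_swapped_seqs
    return num_free_blocks - num_required_blocks >= watermark_blocks

assert can_swap_in(num_free_blocks=10, num_blocks_to_swap_in=4,
                   num_swapped_seqs=2, watermark_blocks=2)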
@@ -9,7 +9,7 @@ from cacheflow.core.policy import PolicyFactory
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus

@@ -105,7 +105,7 @@ class Scheduler:
         preempted: List[SequenceGroup] = []
         while self.running:
             seq_group = self.running.pop(0)
-            while not self.block_manager.can_append(seq_group):
+            while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
                     victim_seq_group = self.running.pop(-1)
@@ -119,7 +119,7 @@ class Scheduler:
                     break
             else:
                 # Append new slots to the sequence group.
-                self._append(seq_group, blocks_to_copy)
+                self._append_slot(seq_group, blocks_to_copy)
                 running.append(seq_group)
         self.running = running

@@ -143,7 +143,7 @@ class Scheduler:

                 seq_group = self.swapped.pop(0)
                 self._swap_in(seq_group, blocks_to_swap_in)
-                self._append(seq_group, blocks_to_copy)
+                self._append_slot(seq_group, blocks_to_copy)
                 self.running.append(seq_group)

         num_batched_tokens = sum(
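The loop around can_append_slot/_append_slot preempts the lowest-priority running group whenever the current group cannot get a slot. A toy version of that control flow, with block accounting simplified to one block per new token and preemption modeled as reclaiming the victim's blocks:

from typing import Dict, List, Tuple

def reserve_slots(running: List[str], held: Dict[str, int],
                  free_blocks: int) -> Tuple[List[str], List[str], int]:
    # Toy version of the scheduler loop in the diff: every running group wants
    # one more block for its next token; if none is free, the lowest-priority
    # group is preempted and its blocks become free again.
    kept: List[str] = []
    preempted: List[str] = []
    while running:
        group = running.pop(0)
        while free_blocks < 1:                    # stand-in for can_append_slot()
            if running:
                victim = running.pop(-1)          # preempt the lowest priority
                free_blocks += held[victim]
                preempted.append(victim)
            else:
                preempted.append(group)
                break
        else:
            free_blocks -= 1                      # stand-in for _append_slot()
            kept.append(group)
    return kept, preempted, free_blocks

kept, preempted, left = reserve_slots(["a", "b", "c"], {"a": 2, "b": 2, "c": 2}, 1)
assert kept == ["a", "b"] and preempted == ["c"]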
@@ -252,7 +252,7 @@ class Scheduler:
         prompt_group_ids = scheduler_output[3]

         # Create input data structures.
-        input_seq_groups: List[SequenceGroupInputs] = []
+        seq_group_metadata_list: List[SequenceGroupMetadata] = []
         updated_seq_groups: List[SequenceGroup] = self.running.copy()

         for seq_group in self.running:
@@ -274,7 +274,7 @@ class Scheduler:
                 # sequence length
                 seq_len = seq.get_len()

-            input_seq_group = SequenceGroupInputs(
+            seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
                 input_tokens=input_tokens,
@@ -283,14 +283,14 @@ class Scheduler:
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
-            input_seq_groups.append(input_seq_group)
+            seq_group_metadata_list.append(seq_group_metadata)

         # Execute the first stage of the pipeline.
-        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
+        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
             # Swap in and swap out should never happen at the same time.
             assert not (blocks_to_swap_in and blocks_to_swap_out)
             self.controllers[0].execute_stage(
-                input_seq_groups,
+                seq_group_metadata_list,
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,
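The renamed SequenceGroupMetadata is built once per running group and handed to the first pipeline stage. A sketch with only the fields visible in this diff; sampling_params and context_len are omitted, and the defaults are assumptions:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SequenceGroupMetadata:
    # Field names mirror the keyword arguments and attribute accesses in the
    # diff; everything else about the real class is an assumption.
    group_id: int
    is_prompt: bool
    input_tokens: Dict[int, List[int]]
    seq_logprobs: Dict[int, float] = field(default_factory=dict)
    block_tables: Dict[int, List[int]] = field(default_factory=dict)

# One metadata record per running group is collected per scheduling step and
# forwarded to the first pipeline stage.
seq_group_metadata_list: List[SequenceGroupMetadata] = [
    SequenceGroupMetadata(group_id=0, is_prompt=True,
                          input_tokens={0: [5, 7, 11]},
                          block_tables={0: [3]}),
]
assert seq_group_metadata_list[0].is_prompt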
@@ -330,7 +330,7 @@ class Scheduler:

                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append(output.output_token, output.logprobs)
+                seq.append_token(output.output_token, output.logprobs)

                 # Check if the sequence has generated a stop token.
                 if output.output_token in stop_token_ids:
@@ -360,13 +360,13 @@ class Scheduler:
         if seq_group.group_id not in self.num_steps:
             self.num_steps[seq_group.group_id] = 0

-    def _append(
+    def _append_slot(
         self,
         seq_group: SequenceGroup,
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            ret = self.block_manager.append(seq)
+            ret = self.block_manager.append_slot(seq)
             if ret is not None:
                 src_block, dst_block = ret
                 if src_block in blocks_to_copy:
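append_slot returns a (src_block, dst_block) pair when appending a slot triggers a copy-on-write, and _append_slot records it in blocks_to_copy. A sketch of that bookkeeping; the helper name is made up for illustration:

from typing import Dict, List

def record_copy(blocks_to_copy: Dict[int, List[int]], src_block: int,
                dst_block: int) -> None:
    # Mirrors the tail of _append_slot in the diff: when a copy-on-write is
    # needed, remember the source -> destination block pair for the workers.
    if src_block in blocks_to_copy:
        blocks_to_copy[src_block].append(dst_block)
    else:
        blocks_to_copy[src_block] = [dst_block]

blocks_to_copy: Dict[int, List[int]] = {}
record_copy(blocks_to_copy, src_block=4, dst_block=9)
record_copy(blocks_to_copy, src_block=4, dst_block=12)
assert blocks_to_copy == {4: [9, 12]}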
@@ -1,4 +1,4 @@
-from typing import Optional, Set, Dict
+from typing import Dict, Set


 class SamplingParams:
@@ -12,7 +12,6 @@ class SamplingParams:
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
-        context_window_size: Optional[int],
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')
@@ -27,10 +26,6 @@ class SamplingParams:
         if num_logprobs < 0:
             raise ValueError(
                 f'num_logprobs must be non-negative, got {num_logprobs}.')
-        if context_window_size is not None and context_window_size < 0:
-            raise ValueError(
-                'context_window_size must be non-negative, '
-                f'got {context_window_size}.')

         if use_beam_search:
             if n == 1:
@@ -58,7 +53,6 @@ class SamplingParams:
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
-        self.context_window_size = context_window_size

     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '
@@ -67,8 +61,7 @@ class SamplingParams:
                 f'use_beam_search={self.use_beam_search}, '
                 f'stop_token_ids={self.stop_token_ids}, '
                 f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}, '
-                f'context_window_size={self.context_window_size})')
+                f'num_logprobs={self.num_logprobs}')

     @classmethod
     def from_dict(cls, d: Dict) -> 'SamplingParams':
@@ -80,5 +73,4 @@ class SamplingParams:
             stop_token_ids=set(d.get('stop_token_ids', set())),
             max_num_steps=d.get('max_num_steps', 16),
             num_logprobs=d.get('num_logprobs', 0),
-            context_window_size=d.get('context_window_size', None),
         )
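With context_window_size gone, from_dict keeps the dict-with-defaults pattern shown above. A tiny sketch of just that pattern, limited to the fields visible in the hunk; the remaining SamplingParams fields are omitted:

from typing import Dict

def sampling_kwargs(d: Dict) -> Dict:
    # Defaults visible in the from_dict hunk; missing keys fall back to them.
    return dict(
        stop_token_ids=set(d.get('stop_token_ids', set())),
        max_num_steps=d.get('max_num_steps', 16),
        num_logprobs=d.get('num_logprobs', 0),
    )

assert sampling_kwargs({})['max_num_steps'] == 16
assert sampling_kwargs({'num_logprobs': 2})['num_logprobs'] == 2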
@@ -18,45 +18,46 @@ class Sequence:
     def __init__(
         self,
         seq_id: int,
-        token_ids: List[int],
+        prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
         self.block_size = block_size

+        self.prompt_len = len(prompt_token_ids)
         self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the given token ids.
-        self.add(token_ids)
+        # Initialize the logical token blocks with the prompt token ids.
+        self._append_tokens(prompt_token_ids)

-        self.prompt_len = len(token_ids)
         self.status = SequenceStatus.WAITING
+        # Used for beam search.
         self.output_logprobs: List[Dict[int, float]] = []
         self.cumulative_logprobs = 0.0

-    def add_block(self) -> None:
+    def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
             block_size=self.block_size,
         )
         self.logical_token_blocks.append(block)

-    def add(self, token_ids: List[int]) -> None:
+    def _append_tokens(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
-                self.add_block()
+                self._append_logical_block()

             last_block = self.logical_token_blocks[-1]
             if last_block.is_full():
-                self.add_block()
+                self._append_logical_block()
                 last_block = self.logical_token_blocks[-1]

             num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append(token_ids[:num_empty_slots])
+            last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]

-    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self.add([token_id])
+        self._append_tokens([token_id])
         self.output_logprobs.append(logprobs)
         self.cumulative_logprobs += logprobs[token_id]

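The renamed _append_tokens fills the last logical block and opens a new one whenever it runs out of empty slots. A compact sketch of the resulting block layout for a fresh prompt; the list-of-lists stands in for LogicalTokenBlock:

from typing import List

def split_into_blocks(token_ids: List[int], block_size: int) -> List[List[int]]:
    # Same walk as _append_tokens: fill the last block, open a new one when it
    # is full, and keep going until every token has a slot.
    blocks: List[List[int]] = []
    while token_ids:
        if not blocks or len(blocks[-1]) == block_size:
            blocks.append([])                       # _append_logical_block()
        empty = block_size - len(blocks[-1])        # get_num_empty_slots()
        blocks[-1].extend(token_ids[:empty])        # append_tokens()
        token_ids = token_ids[empty:]
    return blocks

assert split_into_blocks(list(range(10)), block_size=4) == [
    [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]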
@@ -121,7 +122,7 @@ class SequenceGroup:
                 f'num_seqs={len(self.seqs)})')


-class SequenceGroupInputs:
+class SequenceGroupMetadata:

     def __init__(
         self,
@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Tuple, Optional
+from typing import List, Optional, Tuple, Union

 try:
     import ray
@@ -6,7 +6,6 @@ except ImportError:
     ray = None

 from cacheflow.core.scheduler import Scheduler
-from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker


@@ -81,23 +80,12 @@ class Controller:
         self.next_node = next_node
         self.is_last_stage = isinstance(next_node, Scheduler)

-    def execute_stage(
-        self,
-        input_seq_groups: List[SequenceGroupInputs],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> None:
+    def execute_stage(self, *args, **kwargs) -> None:
         all_outputs = []
         for worker in self.workers:
             executor = (worker.execute_stage.remote
                         if self.use_ray else worker.execute_stage)
-            output = executor(
-                input_seq_groups,
-                blocks_to_swap_in,
-                blocks_to_swap_out,
-                blocks_to_copy,
-            )
+            output = executor(*args, **kwargs)
             all_outputs.append(output)

         if self.use_ray:
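After this change, Controller.execute_stage no longer spells out the argument list; it forwards *args/**kwargs to every worker. A sketch of that fan-out with a stand-in worker; the Ray remote path is omitted:

from typing import Any, List

class FakeWorker:
    # Stand-in for Worker.execute_stage; it only reports how it was called.
    def execute_stage(self, *args: Any, **kwargs: Any) -> int:
        return len(args) + len(kwargs)

def execute_stage(workers: List[FakeWorker], *args: Any, **kwargs: Any) -> List[int]:
    # Same shape as the refactored Controller.execute_stage: whatever the
    # scheduler passes is forwarded untouched to every worker, so the
    # controller no longer has to track the argument list itself.
    return [worker.execute_stage(*args, **kwargs) for worker in workers]

outputs = execute_stage([FakeWorker(), FakeWorker()],
                        [], blocks_to_swap_in={}, blocks_to_swap_out={},
                        blocks_to_copy={})
assert outputs == [4, 4]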
@@ -8,10 +8,11 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine


 class Worker:

     def __init__(
@@ -93,30 +94,29 @@ class Worker:

     def prepare_inputs(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         seq_logprobs: Dict[int, float] = {}
-        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []

         # Add prompt tokens.
         prompt_lens: List[int] = []
-        for input_seq_group in input_seq_groups:
-            if not input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if not seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             # Use any sequence in the group.
             seq_id = seq_ids[0]

-            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

@@ -126,7 +126,7 @@ class Worker:
             input_positions.extend(range(len(prompt_tokens)))

             # Compute the slot mapping.
-            block_table = input_seq_group.block_tables[seq_id]
+            block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
@@ -138,31 +138,31 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        for input_seq_group in input_seq_groups:
-            if input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if seq_group_metadata.is_prompt:
                 continue

-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)

             for seq_id in seq_ids:
-                assert len(input_seq_group.input_tokens[seq_id]) == 1
-                generation_token = input_seq_group.input_tokens[seq_id][0]
+                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
+                generation_token = seq_group_metadata.input_tokens[seq_id][0]
                 input_tokens.append(generation_token)

-                position = input_seq_group.context_len - 1
+                position = seq_group_metadata.context_len - 1
                 input_positions.append(position)

-                block_table = input_seq_group.block_tables[seq_id]
+                block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)

                 max_context_len = max(
-                    max_context_len, input_seq_group.context_len)
+                    max_context_len, seq_group_metadata.context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(input_seq_group.context_len)
+                context_lens.append(seq_group_metadata.context_len)

                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
@@ -203,30 +203,30 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
-        command_issued = False
+        issued_cache_op = False
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            command_issued = True
+            issued_cache_op = True

-        if command_issued:
+        if issued_cache_op:
             cache_events = self.cache_events
         else:
             cache_events = None

         # If there is no input, we don't need to execute the model.
-        if not input_seq_groups:
+        if not seq_group_metadata_list:
             if cache_events is not None:
                 for event in cache_events:
                     event.wait()
@@ -234,7 +234,7 @@ class Worker:

         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            input_seq_groups)
+            seq_group_metadata_list)

         # Execute the model.
         output = self.model(
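issued_cache_op replaces command_issued, but the logic is unchanged: the worker waits on cache events only if at least one swap or copy was issued. A sketch of just that flag, with the cache_engine calls reduced to comments:

from typing import Dict, List

def cache_ops_issued(blocks_to_swap_in: Dict[int, int],
                     blocks_to_swap_out: Dict[int, int],
                     blocks_to_copy: Dict[int, List[int]]) -> bool:
    # Same bookkeeping as the renamed issued_cache_op flag above.
    issued_cache_op = False
    if blocks_to_swap_in:
        issued_cache_op = True   # cache_engine.swap_in(...) would run here
    if blocks_to_swap_out:
        issued_cache_op = True   # cache_engine.swap_out(...) would run here
    if blocks_to_copy:
        issued_cache_op = True   # cache_engine.copy(...) would run here
    return issued_cache_op

assert cache_ops_issued({}, {}, {}) is False
assert cache_ops_issued({1: 2}, {}, {}) is True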