Raise error for long prompt (#273)

Author: Lily Liu · 2023-06-30 18:48:49 -07:00 · committed by GitHub
parent 598dc4b79a
commit dafd924c1f
5 changed files, 42 additions and 11 deletions

vllm/config.py

@@ -186,14 +186,18 @@ class SchedulerConfig:
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single
             iteration.
+        max_seq_len: Maximum length of a sequence (including prompt
+            and generated text).
     """
 
     def __init__(
         self,
         max_num_batched_tokens: int,
         max_num_seqs: int,
+        max_seq_len: int
     ) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
+        self.max_seq_len = max_seq_len
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
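
With this change, SchedulerConfig carries the per-sequence length cap alongside the existing batching limits. A minimal sketch of constructing it directly (the values are illustrative only; in practice EngineArgs computes and passes max_seq_len, as shown further down in this commit):

from vllm.config import SchedulerConfig

# Illustrative values only; the real ones come from EngineArgs below.
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2560,  # token budget per scheduler iteration
    max_num_seqs=256,             # sequence budget per scheduler iteration
    max_seq_len=2048,             # cap on prompt + generated tokens
)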

vllm/core/scheduler.py

@@ -102,11 +102,12 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
+        ignored_seq_groups: List[SequenceGroup] = []
 
         # Fix the current time.
         now = time.time()
@@ -187,12 +188,24 @@ class Scheduler:
                 # If the sequence group has been preempted in this step, stop.
                 if seq_group in preempted:
                     break
 
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if num_prompt_tokens >= self.scheduler_config.max_seq_len:
+                    logger.warn(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        " and exceeds limit of "
+                        f"{self.scheduler_config.max_seq_len}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
                 # If the sequence group cannot be allocated, stop.
                 if not self.block_manager.can_allocate(seq_group):
                     break
 
                 # If the number of batched tokens exceeds the limit, stop.
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if (num_batched_tokens + num_prompt_tokens
                     > self.scheduler_config.max_num_batched_tokens):
                     break
@ -218,7 +231,7 @@ class Scheduler:
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
) )
if not self.log_stats: if not self.log_stats:
return scheduler_outputs, prompt_group_ids return scheduler_outputs, prompt_group_ids, ignored_seq_groups
# TODO(woosuk): Move the below code to the engine. # TODO(woosuk): Move the below code to the engine.
now = time.time() now = time.time()
@@ -258,13 +271,13 @@ class Scheduler:
                     f"Pending: {len(self.waiting)} reqs, "
                     f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids
+        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
 
-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids = self._schedule()
+        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -286,7 +299,7 @@ class Scheduler:
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs
+        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
 
     def update(
         self,

vllm/engine/arg_utils.py

@@ -123,8 +123,12 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
+        max_seq_len = min(
+            self.max_num_batched_tokens,
+            getattr(model_config.hf_config, "max_position_embeddings",
+                    float("inf")))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs)
+                                           self.max_num_seqs, max_seq_len)
         return model_config, cache_config, parallel_config, scheduler_config
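
The derived max_seq_len is the tighter of the scheduler's token budget and the model's positional limit; when the HF config has no max_position_embeddings attribute, getattr falls back to infinity and only max_num_batched_tokens constrains prompt length. A standalone sketch of that min() with made-up numbers:

# Worked example of the derivation above; the numbers are made up.
max_num_batched_tokens = 2560
max_position_embeddings = 2048  # e.g. what an OPT-style hf_config reports
assert min(max_num_batched_tokens, max_position_embeddings) == 2048

# Without a max_position_embeddings attribute the fallback is float("inf"),
# so the batching budget alone bounds the sequence length.
assert min(max_num_batched_tokens, float("inf")) == 2560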

vllm/engine/llm_engine.py

@@ -226,8 +226,8 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
+        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
+        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
             # Nothing to do.
             return []
@@ -251,7 +251,7 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups:
+        for seq_group in seq_groups + ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
@@ -288,6 +288,12 @@ class LLMEngine:
             if stopped:
                 continue
 
+            # Check if the sequence has reached max_seq_len.
+            if (seq.get_len() >=
+                self.scheduler.scheduler_config.max_seq_len):
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
+                continue
 
             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
                 self.scheduler.free_seq(
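
Note that the commit ends up with two distinct length guards: a prompt already at or over max_seq_len is marked FINISHED_IGNORED before it is ever scheduled (scheduler hunk above), while a sequence that reaches max_seq_len during decoding is capped here with FINISHED_LENGTH_CAPPED. A simplified sketch of the distinction (names taken from this diff; not a drop-in copy of the engine code):

# Simplified illustration only: how the two guards relate.
def classify(prompt_len: int, generated_len: int, max_seq_len: int) -> str:
    if prompt_len >= max_seq_len:
        return "FINISHED_IGNORED"        # dropped by the scheduler, never run
    if prompt_len + generated_len >= max_seq_len:
        return "FINISHED_LENGTH_CAPPED"  # stopped while generating
    return "RUNNING"

Both finished states are reported to clients as finish_reason "length", per the sequence.py hunks that follow.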

vllm/sequence.py

@@ -13,6 +13,7 @@ class SequenceStatus(enum.Enum):
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
     FINISHED_ABORTED = enum.auto()
+    FINISHED_IGNORED = enum.auto()
 
     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
@@ -20,6 +21,7 @@ class SequenceStatus(enum.Enum):
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
             SequenceStatus.FINISHED_ABORTED,
+            SequenceStatus.FINISHED_IGNORED
         ]
 
     @staticmethod
@@ -30,6 +32,8 @@ class SequenceStatus(enum.Enum):
             finish_reason = "length"
         elif status == SequenceStatus.FINISHED_ABORTED:
             finish_reason = "abort"
+        elif status == SequenceStatus.FINISHED_IGNORED:
+            finish_reason = "length"
         else:
             finish_reason = None
         return finish_reason
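
Taken together, an over-long prompt now comes back as an already-finished request instead of crashing the engine. A rough sketch of the user-visible effect (the model choice and prompt are assumptions for illustration; the finish_reason wiring follows from this diff, not from a reproduced run):

from vllm import LLM, SamplingParams

# Assumed setup for illustration; any model supported at this commit works.
llm = LLM(model="facebook/opt-125m")
too_long_prompt = "word " * 5000  # far beyond a 2048-token max_seq_len

outputs = llm.generate([too_long_prompt], SamplingParams(max_tokens=16))
for output in outputs:
    completion = output.outputs[0]
    # Ignored prompts return with no generated text and finish_reason "length"
    # (FINISHED_IGNORED maps to "length" in SequenceStatus above).
    print(completion.finish_reason, repr(completion.text))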