Raise error for long prompt (#273)
parent 598dc4b79a
commit dafd924c1f
@@ -186,14 +186,18 @@ class SchedulerConfig:
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single
             iteration.
+        max_seq_len: Maximum length of a sequence (including prompt
+            and generated text).
     """
     def __init__(
         self,
         max_num_batched_tokens: int,
         max_num_seqs: int,
+        max_seq_len: int
     ) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
+        self.max_seq_len = max_seq_len


 _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -102,11 +102,12 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
+        ignored_seq_groups: List[SequenceGroup] = []

         # Fix the current time.
         now = time.time()
@@ -187,12 +188,24 @@ class Scheduler:
                 # If the sequence group has been preempted in this step, stop.
                 if seq_group in preempted:
                     break
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if num_prompt_tokens >= self.scheduler_config.max_seq_len:
+                    logger.warn(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        " and exceeds limit of "
+                        f"{self.scheduler_config.max_seq_len}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
                 # If the sequence group cannot be allocated, stop.
                 if not self.block_manager.can_allocate(seq_group):
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if (num_batched_tokens + num_prompt_tokens
                     > self.scheduler_config.max_num_batched_tokens):
                     break
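The new check in the waiting loop, shown as a minimal standalone sketch (toy stand-in types and token counts, not the real SequenceGroup/SequenceStatus/block-manager machinery): a prompt whose length already meets max_seq_len is marked ignored and popped from the waiting queue instead of being scheduled, and, as in the commit, the loop then breaks.

from dataclasses import dataclass, field
from typing import List

# Toy stand-ins for illustration only.
@dataclass
class ToySeqGroup:
    request_id: str
    prompt_len: int
    status: str = "WAITING"

@dataclass
class ToyScheduler:
    max_seq_len: int
    waiting: List[ToySeqGroup] = field(default_factory=list)

    def schedule_waiting(self) -> List[ToySeqGroup]:
        ignored: List[ToySeqGroup] = []
        while self.waiting:
            group = self.waiting[0]
            if group.prompt_len >= self.max_seq_len:
                # Too long: mark as ignored and drop it from the queue.
                group.status = "FINISHED_IGNORED"
                ignored.append(group)
                self.waiting.pop(0)
                break
            # ... allocation and batching-budget checks would follow here ...
            break
        return ignored

scheduler = ToyScheduler(max_seq_len=2048,
                         waiting=[ToySeqGroup("req-0", prompt_len=4096),
                                  ToySeqGroup("req-1", prompt_len=16)])
print([g.request_id for g in scheduler.schedule_waiting()])  # ['req-0']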
@@ -218,7 +231,7 @@ class Scheduler:
             blocks_to_copy=blocks_to_copy,
         )
         if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids
+            return scheduler_outputs, prompt_group_ids, ignored_seq_groups

         # TODO(woosuk): Move the below code to the engine.
         now = time.time()
@@ -258,13 +271,13 @@ class Scheduler:
             f"Pending: {len(self.waiting)} reqs, "
             f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
             f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids
+        return scheduler_outputs, prompt_group_ids, ignored_seq_groups

-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids = self._schedule()
+        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -286,7 +299,7 @@ class Scheduler:
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs
+        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups

     def update(
         self,
@@ -123,8 +123,12 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
+        max_seq_len = min(
+            self.max_num_batched_tokens,
+            getattr(model_config.hf_config, "max_position_embeddings",
+                    float("inf")))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs)
+                                           self.max_num_seqs, max_seq_len)
         return model_config, cache_config, parallel_config, scheduler_config
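The engine-side limit is simply the tighter of the batching budget and the model's positional limit. A small hedged example of the same min/getattr pattern (the numbers and config objects below are made up for illustration; the real values come from EngineArgs and the model's Hugging Face config):

from types import SimpleNamespace

max_num_batched_tokens = 2560
hf_config = SimpleNamespace(max_position_embeddings=2048)

max_seq_len = min(
    max_num_batched_tokens,
    getattr(hf_config, "max_position_embeddings", float("inf")))
print(max_seq_len)  # 2048: the model's positional limit binds

# A config without max_position_embeddings falls back to the batching budget.
hf_config_no_limit = SimpleNamespace()
print(min(max_num_batched_tokens,
          getattr(hf_config_no_limit, "max_position_embeddings",
                  float("inf"))))  # 2560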
@@ -226,8 +226,8 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
+        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
+        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
             # Nothing to do.
             return []

@@ -251,7 +251,7 @@ class LLMEngine:

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups:
+        for seq_group in seq_groups + ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
@@ -288,6 +288,12 @@ class LLMEngine:
                 if stopped:
                     continue

+                # Check if the sequence has reached max_seq_len.
+                if (seq.get_len() >=
+                    self.scheduler.scheduler_config.max_seq_len):
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
+                    continue
                 # Check if the sequence has reached max_tokens.
                 if seq.get_output_len() == sampling_params.max_tokens:
                     self.scheduler.free_seq(
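The decode-time counterpart, as a standalone sketch (toy enum and function names; the real code calls scheduler.free_seq inside the engine's stop-condition loop): once a sequence's total length, prompt plus generated tokens, reaches max_seq_len, it is finished as length-capped rather than scheduled for another step.

from enum import Enum, auto
from typing import List

class ToyStatus(Enum):
    RUNNING = auto()
    FINISHED_LENGTH_CAPPED = auto()

def cap_finished_sequences(seq_lens: List[int], max_seq_len: int) -> List[ToyStatus]:
    # Toy version of the per-step length check (illustration only).
    statuses: List[ToyStatus] = []
    for seq_len in seq_lens:
        if seq_len >= max_seq_len:
            statuses.append(ToyStatus.FINISHED_LENGTH_CAPPED)
        else:
            statuses.append(ToyStatus.RUNNING)
    return statuses

# The first sequence hits the cap and is finished; the second keeps running.
print(cap_finished_sequences([2048, 17], max_seq_len=2048))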
@@ -13,6 +13,7 @@ class SequenceStatus(enum.Enum):
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
     FINISHED_ABORTED = enum.auto()
+    FINISHED_IGNORED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
@@ -20,6 +21,7 @@ class SequenceStatus(enum.Enum):
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
             SequenceStatus.FINISHED_ABORTED,
+            SequenceStatus.FINISHED_IGNORED
         ]

     @staticmethod
@@ -30,6 +32,8 @@ class SequenceStatus(enum.Enum):
             finish_reason = "length"
         elif status == SequenceStatus.FINISHED_ABORTED:
             finish_reason = "abort"
+        elif status == SequenceStatus.FINISHED_IGNORED:
+            finish_reason = "length"
         else:
             finish_reason = None
         return finish_reason