diff --git a/vllm/config.py b/vllm/config.py
index ebbc7168..c2eba07d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -186,14 +186,18 @@ class SchedulerConfig:
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in
             a single iteration.
+        max_seq_len: Maximum length of a sequence (including prompt
+            and generated text).
     """

     def __init__(
         self,
         max_num_batched_tokens: int,
         max_num_seqs: int,
+        max_seq_len: int
     ) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
+        self.max_seq_len = max_seq_len


 _STR_DTYPE_TO_TORCH_DTYPE = {
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index a1fb02db..7160c4f4 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -102,11 +102,12 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
+        ignored_seq_groups: List[SequenceGroup] = []

         # Fix the current time.
         now = time.time()
@@ -187,12 +188,24 @@ class Scheduler:
                 # If the sequence group has been preempted in this step, stop.
                 if seq_group in preempted:
                     break
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if num_prompt_tokens >= self.scheduler_config.max_seq_len:
+                    logger.warn(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        " and exceeds limit of "
+                        f"{self.scheduler_config.max_seq_len}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
                 # If the sequence group cannot be allocated, stop.
                 if not self.block_manager.can_allocate(seq_group):
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if (num_batched_tokens + num_prompt_tokens
                         > self.scheduler_config.max_num_batched_tokens):
                     break
@@ -218,7 +231,7 @@ class Scheduler:
             blocks_to_copy=blocks_to_copy,
         )
         if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids
+            return scheduler_outputs, prompt_group_ids, ignored_seq_groups

         # TODO(woosuk): Move the below code to the engine.
         now = time.time()
@@ -258,13 +271,13 @@ class Scheduler:
             f"Pending: {len(self.waiting)} reqs, "
             f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
             f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids
+        return scheduler_outputs, prompt_group_ids, ignored_seq_groups

-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids = self._schedule()
+        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -286,7 +299,7 @@ class Scheduler:
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs
+        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups

     def update(
         self,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 23bdcb3a..e32ef005 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -123,8 +123,12 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
+        max_seq_len = min(
+            self.max_num_batched_tokens,
+            getattr(model_config.hf_config, "max_position_embeddings",
+                    float("inf")))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs)
+                                           self.max_num_seqs, max_seq_len)
         return model_config, cache_config, parallel_config, scheduler_config


diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d4397ee3..c9b4215d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -226,8 +226,8 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
+        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
+        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
             # Nothing to do.
             return []

@@ -251,7 +251,7 @@

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups:
+        for seq_group in seq_groups + ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
@@ -288,6 +288,12 @@
                 if stopped:
                     continue

+                # Check if the sequence has reached max_seq_len.
+                if (seq.get_len() >=
+                        self.scheduler.scheduler_config.max_seq_len):
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
+                    continue
                 # Check if the sequence has reached max_tokens.
                 if seq.get_output_len() == sampling_params.max_tokens:
                     self.scheduler.free_seq(
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 5fe84729..083e6d03 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -13,6 +13,7 @@ class SequenceStatus(enum.Enum):
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
     FINISHED_ABORTED = enum.auto()
+    FINISHED_IGNORED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
@@ -20,6 +21,7 @@ class SequenceStatus(enum.Enum):
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
             SequenceStatus.FINISHED_ABORTED,
+            SequenceStatus.FINISHED_IGNORED
         ]

     @staticmethod
@@ -30,6 +32,8 @@ class SequenceStatus(enum.Enum):
             finish_reason = "length"
         elif status == SequenceStatus.FINISHED_ABORTED:
             finish_reason = "abort"
+        elif status == SequenceStatus.FINISHED_IGNORED:
+            finish_reason = "length"
         else:
             finish_reason = None
         return finish_reason
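
For reference, a minimal sketch of how the new behavior is expected to surface through the public vllm.LLM API. This snippet is not part of the patch; the model name and prompt size are arbitrary placeholders chosen for illustration.

# Sketch only: a prompt longer than max_seq_len is no longer scheduled. The
# scheduler marks its sequences FINISHED_IGNORED, logs a warning, and the
# engine returns it as a regular RequestOutput with finish_reason "length"
# and empty output text, instead of stalling or crashing on the request.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")            # placeholder model
too_long_prompt = "hello " * 100_000            # far beyond max_position_embeddings

outputs = llm.generate([too_long_prompt], SamplingParams(max_tokens=16))
for out in outputs:
    # Expected: finish_reason == "length", text == ""
    print(out.outputs[0].finish_reason, repr(out.outputs[0].text))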