From b1c255630db60e08c394964b8ed6c0154d31a29f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 30 May 2024 07:05:01 +0800 Subject: [PATCH] [Core] Avoid the need to pass `None` values to `Sequence.inputs` (#5099) --- tests/core/test_block_manager.py | 2 -- tests/core/utils.py | 7 +------ tests/engine/output_processor/test_stop_checker.py | 6 +----- tests/test_cache_block_hashing.py | 1 - tests/tokenization/test_detokenize.py | 1 - vllm/inputs.py | 4 ++-- vllm/sequence.py | 4 ++-- 7 files changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index ddd84317..cd306b9e 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -234,7 +234,6 @@ def test_append_slot_cow(): inputs={ "prompt": "one two three", "prompt_token_ids": [1, 2, 3], - "multi_modal_data": None }, block_size=block_size) @@ -525,7 +524,6 @@ def test_sliding_window_multi_seq(): inputs={ "prompt": "one two three", "prompt_token_ids": [0, 1, 2], - "multi_modal_data": None }, block_size=block_size) seq_group = SequenceGroup(request_id="1", diff --git a/tests/core/utils.py b/tests/core/utils.py index cd2045b8..2fbf099c 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -25,7 +25,6 @@ def create_dummy_prompt( inputs={ "prompt": prompt_str, "prompt_token_ids": prompt_tokens, - "multi_modal_data": None, }, block_size=block_size) seq_group = SequenceGroup(request_id=request_id, @@ -103,11 +102,7 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - }, + inputs={"prompt_token_ids": prompt_token_ids}, block_size=16, ) diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index 1d9c878d..f795403e 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -15,11 +15,7 @@ def sequence_with_eos(text: str, eos_token: str, """ seq = Sequence( seq_id=0, - inputs={ - "prompt": "", - "prompt_token_ids": [], - "multi_modal_data": None, - }, + inputs={"prompt_token_ids": []}, block_size=16, eos_token_id=eos_token_id, ) diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 97864af8..0fbe3dae 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -74,7 +74,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, inputs={ "prompt": prompt, "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, }, block_size=block_size, eos_token_id=tokenizer.tokenizer.eos_token_id, diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 1d4c74d6..8d019fe5 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -126,7 +126,6 @@ def create_sequence(prompt_token_ids=None): inputs={ "prompt": "", "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, }, block_size=16, ) diff --git a/vllm/inputs.py b/vllm/inputs.py index f5d99b1b..85c9cd84 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -126,5 +126,5 @@ PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] class LLMInputs(TypedDict): prompt_token_ids: List[int] - prompt: Optional[str] - multi_modal_data: Optional["MultiModalData"] + prompt: NotRequired[Optional[str]] + multi_modal_data: NotRequired[Optional["MultiModalData"]] diff --git a/vllm/sequence.py b/vllm/sequence.py index ee8c94bb..ac5c234d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -249,7 +249,7 @@ class Sequence: @property def prompt(self) -> Optional[str]: - return self.inputs["prompt"] + return self.inputs.get("prompt") @property def prompt_token_ids(self) -> List[int]: @@ -257,7 +257,7 @@ class Sequence: @property def multi_modal_data(self) -> Optional["MultiModalData"]: - return self.inputs["multi_modal_data"] + return self.inputs.get("multi_modal_data") @property def lora_int_id(self) -> int: