[BugFix] Fix server crash on empty prompt (#7746)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
parent
faeddb565d
commit
e25fee57c2
9
tests/entrypoints/llm/test_prompt_validation.py
Normal file
9
tests/entrypoints/llm/test_prompt_validation.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_prompt():
    """Generating from an empty prompt must raise, not crash the engine."""
    engine = LLM(model="gpt2")
    with pytest.raises(ValueError, match='Prompt cannot be empty'):
        engine.generate([""])
22
tests/entrypoints/openai/test_prompt_validation.py
Normal file
22
tests/entrypoints/openai/test_prompt_validation.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# imports for guided decoding tests
|
||||||
|
import re
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_empty_prompt():
    """An empty prompt over the OpenAI-compatible API must yield a 400
    BadRequestError carrying the server-side validation message."""
    served_model = "gpt2"
    launch_args = ["--enforce-eager"]
    with RemoteOpenAIServer(served_model, launch_args) as remote_server:
        client = remote_server.get_async_client()

        with pytest.raises(openai.BadRequestError,
                           match=re.compile('.+Prompt cannot be empty.+')):
            await client.completions.create(model=served_model,
                                            prompt="",
                                            max_tokens=5,
                                            temperature=0.0)
@ -591,6 +591,7 @@ class LLMEngine:
|
|||||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self._validate_model_inputs(processed_inputs)
|
||||||
# Create the sequences.
|
# Create the sequences.
|
||||||
block_size = self.cache_config.block_size
|
block_size = self.cache_config.block_size
|
||||||
seq_id = next(self.seq_counter)
|
seq_id = next(self.seq_counter)
|
||||||
@ -1647,3 +1648,10 @@ class LLMEngine:
|
|||||||
|
|
||||||
def is_embedding_model(self):
    """Whether the configured model is an embedding model."""
    config = self.model_config
    return config.is_embedding_model
|
|
||||||
|
def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                               EncoderDecoderLLMInputs]):
    """Reject requests whose prompt token list is absent or empty.

    For encoder/decoder models the encoder-side prompt is the one that
    must be non-empty; for decoder-only models it is the regular prompt.

    Raises:
        ValueError: if the relevant token-id list is missing or empty.
    """
    if self.is_encoder_decoder_model():
        prompt_key = "encoder_prompt_token_ids"
    else:
        prompt_key = "prompt_token_ids"
    if not inputs.get(prompt_key):
        raise ValueError("Prompt cannot be empty")
|
|||||||
Loading…
Reference in New Issue
Block a user