[Bugfix][Frontend] Reject guided decoding in multistep mode (#9892)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
b63c64d95b
commit
031a7995f3
@ -283,7 +283,7 @@ Feature x Feature
|
|||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
- `✗ <https://github.com/vllm-project/vllm/issues/8985>`__
|
- `✗ <https://github.com/vllm-project/vllm/issues/9893>`__
|
||||||
- ?
|
- ?
|
||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
|
|||||||
@ -35,3 +35,23 @@ async def test_out_of_vocab_token_ids():
|
|||||||
prompt=[999999],
|
prompt=[999999],
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_multistep_with_guided_decoding():
|
||||||
|
model_name = "gpt2"
|
||||||
|
server_args = ["--enforce-eager", "--num-scheduler-steps", "8"]
|
||||||
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
|
client = remote_server.get_async_client()
|
||||||
|
|
||||||
|
with pytest.raises(openai.BadRequestError,
|
||||||
|
match=re.compile(
|
||||||
|
'.*Guided decoding .* multi-step decoding.*')):
|
||||||
|
await client.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
prompt="Hello",
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.0,
|
||||||
|
extra_body={"response_format": {
|
||||||
|
"type": "json_object"
|
||||||
|
}})
|
||||||
|
|||||||
@ -829,6 +829,13 @@ class LLMEngine:
|
|||||||
raise ValueError(f"Got priority {priority} but "
|
raise ValueError(f"Got priority {priority} but "
|
||||||
"Priority scheduling is not enabled.")
|
"Priority scheduling is not enabled.")
|
||||||
|
|
||||||
|
if isinstance(params, SamplingParams) \
|
||||||
|
and (params.guided_decoding or params.logits_processors) \
|
||||||
|
and self.scheduler_config.num_scheduler_steps > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Guided decoding and logits processors are not supported "
|
||||||
|
"in multi-step decoding")
|
||||||
|
|
||||||
if arrival_time is None:
|
if arrival_time is None:
|
||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
|
|
||||||
|
|||||||
@ -485,8 +485,8 @@ class SamplingParams(
|
|||||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||||
"spaces_between_special_tokens="
|
"spaces_between_special_tokens="
|
||||||
f"{self.spaces_between_special_tokens}, "
|
f"{self.spaces_between_special_tokens}, "
|
||||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
|
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
|
||||||
f"guided_decoding={self.guided_decoding}")
|
f"guided_decoding={self.guided_decoding})")
|
||||||
|
|
||||||
|
|
||||||
class BeamSearchParams(
|
class BeamSearchParams(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user