[Bugfix][Frontend] Reject guided decoding in multistep mode (#9892)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2024-10-31 19:09:46 -06:00 committed by GitHub
parent b63c64d95b
commit 031a7995f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 30 additions and 3 deletions

View File

@ -283,7 +283,7 @@ Feature x Feature
- ✅
- ✅
- ✅
- `✗ <https://github.com/vllm-project/vllm/issues/8985>`__
- `✗ <https://github.com/vllm-project/vllm/issues/9893>`__
- ?
- ✅
- ✅

View File

@ -35,3 +35,23 @@ async def test_out_of_vocab_token_ids():
prompt=[999999],
max_tokens=5,
temperature=0.0)
@pytest.mark.asyncio
async def test_reject_multistep_with_guided_decoding():
    """Guided decoding requests must be rejected under multi-step scheduling.

    Opens a ``RemoteOpenAIServer`` with ``--num-scheduler-steps 8`` and sends
    a completion request whose ``extra_body`` asks for a ``json_object``
    response format. The call is expected to raise
    ``openai.BadRequestError`` with a message matching the
    guided-decoding / multi-step pattern below.
    """
    model_name = "gpt2"
    multistep_args = ["--enforce-eager", "--num-scheduler-steps", "8"]

    # Pattern the server's rejection message has to match.
    expected_error = re.compile('.*Guided decoding .* multi-step decoding.*')

    # Extra request fields that ask for a JSON-object response format.
    guided_request = {"response_format": {"type": "json_object"}}

    with RemoteOpenAIServer(model_name, multistep_args) as remote_server:
        client = remote_server.get_async_client()

        with pytest.raises(openai.BadRequestError, match=expected_error):
            await client.completions.create(model=model_name,
                                            prompt="Hello",
                                            max_tokens=5,
                                            temperature=0.0,
                                            extra_body=guided_request)

View File

@ -829,6 +829,13 @@ class LLMEngine:
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if isinstance(params, SamplingParams) \
and (params.guided_decoding or params.logits_processors) \
and self.scheduler_config.num_scheduler_steps > 1:
raise ValueError(
"Guided decoding and logits processors are not supported "
"in multi-step decoding")
if arrival_time is None:
arrival_time = time.time()

View File

@ -485,8 +485,8 @@ class SamplingParams(
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
f"guided_decoding={self.guided_decoding}")
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
f"guided_decoding={self.guided_decoding})")
class BeamSearchParams(