[Misc] add fixture to guided processor tests (#6341)

Yihuan Bu 2024-07-12 12:55:39 -04:00 committed by GitHub
parent f9d25c2519
commit b039cbbce3
4 changed files with 144 additions and 206 deletions

View File

@@ -0,0 +1,69 @@
import pytest


@pytest.fixture
def sample_regex():
    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture
def sample_json_schema():
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            },
            "age": {
                "type": "integer"
            },
            "skills": {
                "type": "array",
                "items": {
                    "type": "string",
                    "maxLength": 10
                },
                "minItems": 3
            },
            "work_history": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "company": {
                            "type": "string"
                        },
                        "duration": {
                            "type": "number"
                        },
                        "position": {
                            "type": "string"
                        }
                    },
                    "required": ["company", "position"]
                }
            }
        },
        "required": ["name", "age", "skills", "work_history"]
    }


@pytest.fixture
def sample_guided_choice():
    return [
        "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
        "Ruby", "Swift", "Kotlin"
    ]


@pytest.fixture
def sample_sql_statements():
    return ("""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
""")

View File

@@ -22,53 +22,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"

-TEST_SCHEMA = {
-    "type": "object",
-    "properties": {
-        "name": {
-            "type": "string"
-        },
-        "age": {
-            "type": "integer"
-        },
-        "skills": {
-            "type": "array",
-            "items": {
-                "type": "string",
-                "maxLength": 10
-            },
-            "minItems": 3
-        },
-        "work history": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "company": {
-                        "type": "string"
-                    },
-                    "duration": {
-                        "type": "string"
-                    },
-                    "position": {
-                        "type": "string"
-                    }
-                },
-                "required": ["company", "position"]
-            }
-        }
-    },
-    "required": ["name", "age", "skills", "work history"]
-}
-
-TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
-              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
-
-TEST_CHOICE = [
-    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
-    "Swift", "Kotlin"
-]
-
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -408,7 +361,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat(client: openai.AsyncOpenAI,
-                                  guided_decoding_backend: str):
+                                  guided_decoding_backend: str,
+                                  sample_guided_choice):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -422,10 +376,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE,
+        extra_body=dict(guided_choice=sample_guided_choice,
                         guided_decoding_backend=guided_decoding_backend))
     choice1 = chat_completion.choices[0].message.content
-    assert choice1 in TEST_CHOICE
+    assert choice1 in sample_guided_choice

     messages.append({"role": "assistant", "content": choice1})
     messages.append({
@@ -436,10 +390,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE,
+        extra_body=dict(guided_choice=sample_guided_choice,
                         guided_decoding_backend=guided_decoding_backend))
     choice2 = chat_completion.choices[0].message.content
-    assert choice2 in TEST_CHOICE
+    assert choice2 in sample_guided_choice
     assert choice1 != choice2
@@ -447,7 +401,8 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_json_chat(client: openai.AsyncOpenAI,
-                                guided_decoding_backend: str):
+                                guided_decoding_backend: str,
+                                sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -456,18 +411,18 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]
     chat_completion = await client.chat.completions.create(
         model=MODEL_NAME,
         messages=messages,
         max_tokens=1000,
-        extra_body=dict(guided_json=TEST_SCHEMA,
+        extra_body=dict(guided_json=sample_json_schema,
                         guided_decoding_backend=guided_decoding_backend))
     message = chat_completion.choices[0].message
     assert message.content is not None
     json1 = json.loads(message.content)
-    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json1, schema=sample_json_schema)

     messages.append({"role": "assistant", "content": message.content})
     messages.append({
@@ -480,12 +435,12 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=1000,
-        extra_body=dict(guided_json=TEST_SCHEMA,
+        extra_body=dict(guided_json=sample_json_schema,
                         guided_decoding_backend=guided_decoding_backend))
     message = chat_completion.choices[0].message
     assert message.content is not None
     json2 = json.loads(message.content)
-    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json2, schema=sample_json_schema)
     assert json1["name"] != json2["name"]
     assert json1["age"] != json2["age"]
@@ -494,7 +449,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_regex_chat(client: openai.AsyncOpenAI,
-                                 guided_decoding_backend: str):
+                                 guided_decoding_backend: str, sample_regex):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -502,17 +457,17 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
         "role":
         "user",
         "content":
-        f"Give an example IP address with this regex: {TEST_REGEX}"
+        f"Give an example IP address with this regex: {sample_regex}"
     }]
     chat_completion = await client.chat.completions.create(
         model=MODEL_NAME,
         messages=messages,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX,
+        extra_body=dict(guided_regex=sample_regex,
                         guided_decoding_backend=guided_decoding_backend))
     ip1 = chat_completion.choices[0].message.content
     assert ip1 is not None
-    assert re.fullmatch(TEST_REGEX, ip1) is not None
+    assert re.fullmatch(sample_regex, ip1) is not None

     messages.append({"role": "assistant", "content": ip1})
     messages.append({"role": "user", "content": "Give me a different one"})
@@ -520,11 +475,11 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX,
+        extra_body=dict(guided_regex=sample_regex,
                         guided_decoding_backend=guided_decoding_backend))
     ip2 = chat_completion.choices[0].message.content
     assert ip2 is not None
-    assert re.fullmatch(TEST_REGEX, ip2) is not None
+    assert re.fullmatch(sample_regex, ip2) is not None
     assert ip1 != ip2
@@ -553,7 +508,8 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
-                                           guided_decoding_backend: str):
+                                           guided_decoding_backend: str,
+                                           sample_guided_choice):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -569,7 +525,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
         max_tokens=10,
         logprobs=True,
         top_logprobs=5,
-        extra_body=dict(guided_choice=TEST_CHOICE,
+        extra_body=dict(guided_choice=sample_guided_choice,
                         guided_decoding_backend=guided_decoding_backend))

     assert chat_completion.choices[0].logprobs is not None
@@ -585,7 +541,8 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_named_tool_use(client: openai.AsyncOpenAI,
-                              guided_decoding_backend: str):
+                              guided_decoding_backend: str,
+                              sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -594,7 +551,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]

     # non-streaming
@@ -608,7 +565,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
             "function": {
                 "name": "dummy_function_name",
                 "description": "This is a dummy function",
-                "parameters": TEST_SCHEMA
+                "parameters": sample_json_schema
             }
         }],
         tool_choice={
@@ -621,7 +578,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
     assert len(message.content) == 0
     json_string = message.tool_calls[0].function.arguments
     json1 = json.loads(json_string)
-    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json1, schema=sample_json_schema)

     messages.append({"role": "assistant", "content": json_string})
     messages.append({
@@ -642,7 +599,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
             "function": {
                 "name": "dummy_function_name",
                 "description": "This is a dummy function",
-                "parameters": TEST_SCHEMA
+                "parameters": sample_json_schema
             }
         }],
         tool_choice={
@@ -667,7 +624,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
     # finish reason should only return in last block
     assert finish_reason_count == 1
     json2 = json.loads("".join(output))
-    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json2, schema=sample_json_schema)
     assert json1["name"] != json2["name"]
     assert json1["age"] != json2["age"]
@@ -675,7 +632,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
 async def test_required_tool_use_not_yet_supported(
-        client: openai.AsyncOpenAI, guided_decoding_backend: str):
+        client: openai.AsyncOpenAI, guided_decoding_backend: str,
+        sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -684,7 +642,7 @@ async def test_required_tool_use_not_yet_supported(
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]

     with pytest.raises(openai.BadRequestError):
@@ -697,7 +655,7 @@ async def test_required_tool_use_not_yet_supported(
                 "function": {
                     "name": "dummy_function_name",
                     "description": "This is a dummy function",
-                    "parameters": TEST_SCHEMA
+                    "parameters": sample_json_schema
                 }
             }],
             tool_choice="required")
@@ -712,7 +670,7 @@ async def test_required_tool_use_not_yet_supported(
                 "function": {
                     "name": "dummy_function_name",
                     "description": "This is a dummy function",
-                    "parameters": TEST_SCHEMA
+                    "parameters": sample_json_schema
                 }
             }],
             tool_choice="auto")
@@ -720,8 +678,9 @@ async def test_required_tool_use_not_yet_supported(
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
-async def test_inconsistent_tool_choice_and_tools(
-        client: openai.AsyncOpenAI, guided_decoding_backend: str):
+async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
+                                                  guided_decoding_backend: str,
+                                                  sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -730,7 +689,7 @@ async def test_inconsistent_tool_choice_and_tools(
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]

     with pytest.raises(openai.BadRequestError):
@@ -755,7 +714,7 @@ async def test_inconsistent_tool_choice_and_tools(
                 "function": {
                     "name": "dummy_function_name",
                     "description": "This is a dummy function",
-                    "parameters": TEST_SCHEMA
+                    "parameters": sample_json_schema
                 }
             }],
             tool_choice={
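
The assertions in this file lean on jsonschema.validate as the pass/fail criterion: it returns None when the instance conforms to the schema and raises jsonschema.exceptions.ValidationError otherwise. A self-contained sketch of that behavior, using a deliberately smaller schema than sample_json_schema:

import jsonschema

schema = {
    "type": "object",
    "properties": {
        "age": {
            "type": "integer"
        }
    },
    "required": ["age"]
}
jsonschema.validate(instance={"age": 30}, schema=schema)  # conforms, returns None
try:
    jsonschema.validate(instance={"age": "thirty"}, schema=schema)
except jsonschema.exceptions.ValidationError as err:
    print(err.message)  # 'thirty' is not of type 'integer'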

View File

@@ -24,53 +24,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"

-TEST_SCHEMA = {
-    "type": "object",
-    "properties": {
-        "name": {
-            "type": "string"
-        },
-        "age": {
-            "type": "integer"
-        },
-        "skills": {
-            "type": "array",
-            "items": {
-                "type": "string",
-                "maxLength": 10
-            },
-            "minItems": 3
-        },
-        "work history": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "company": {
-                        "type": "string"
-                    },
-                    "duration": {
-                        "type": "string"
-                    },
-                    "position": {
-                        "type": "string"
-                    }
-                },
-                "required": ["company", "position"]
-            }
-        }
-    },
-    "required": ["name", "age", "skills", "work history"]
-}
-
-TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
-              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
-
-TEST_CHOICE = [
-    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
-    "Swift", "Kotlin"
-]
-
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -529,77 +482,71 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_json_completion(client: openai.AsyncOpenAI,
-                                      guided_decoding_backend: str):
+                                      guided_decoding_backend: str,
+                                      sample_json_schema):
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt=f"Give an example JSON for an employee profile "
-        f"that fits this schema: {TEST_SCHEMA}",
+        f"that fits this schema: {sample_json_schema}",
         n=3,
         temperature=1.0,
         max_tokens=500,
-        extra_body=dict(guided_json=TEST_SCHEMA,
+        extra_body=dict(guided_json=sample_json_schema,
                         guided_decoding_backend=guided_decoding_backend))

     assert completion.id is not None
     assert len(completion.choices) == 3
     for i in range(3):
         output_json = json.loads(completion.choices[i].text)
-        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
+        jsonschema.validate(instance=output_json, schema=sample_json_schema)

 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_regex_completion(client: openai.AsyncOpenAI,
-                                       guided_decoding_backend: str):
+                                       guided_decoding_backend: str,
+                                       sample_regex):
     completion = await client.completions.create(
         model=MODEL_NAME,
-        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
+        prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
         n=3,
         temperature=1.0,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX,
+        extra_body=dict(guided_regex=sample_regex,
                         guided_decoding_backend=guided_decoding_backend))

     assert completion.id is not None
     assert len(completion.choices) == 3
     for i in range(3):
-        assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
+        assert re.fullmatch(sample_regex,
+                            completion.choices[i].text) is not None

 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_completion(client: openai.AsyncOpenAI,
-                                        guided_decoding_backend: str):
+                                        guided_decoding_backend: str,
+                                        sample_guided_choice):
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt="The best language for type-safe systems programming is ",
         n=2,
         temperature=1.0,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE,
+        extra_body=dict(guided_choice=sample_guided_choice,
                         guided_decoding_backend=guided_decoding_backend))

     assert completion.id is not None
     assert len(completion.choices) == 2
     for i in range(2):
-        assert completion.choices[i].text in TEST_CHOICE
+        assert completion.choices[i].text in sample_guided_choice

 @pytest.mark.asyncio
-async def test_guided_grammar(client: openai.AsyncOpenAI):
-    simple_sql_grammar = """
-start: select_statement
-
-select_statement: "SELECT" column "from" table "where" condition
-
-column: "col_1" | "col_2"
-table: "table_1" | "table_2"
-condition: column "=" number
-
-number: "1" | "2"
-"""
+async def test_guided_grammar(client: openai.AsyncOpenAI,
+                              sample_sql_statements):

     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -607,13 +554,13 @@ number: "1" | "2"
                 "table_1 where it is equals to 1"),
         temperature=1.0,
         max_tokens=500,
-        extra_body=dict(guided_grammar=simple_sql_grammar))
+        extra_body=dict(guided_grammar=sample_sql_statements))

     content = completion.choices[0].text

     # use Lark to parse the output, and make sure it's a valid parse tree
     from lark import Lark
-    parser = Lark(simple_sql_grammar)
+    parser = Lark(sample_sql_statements)
     parser.parse(content)

     # remove spaces for comparison b/c we removed them in the grammar
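
The Lark check in this hunk is strict: Lark(grammar) builds a parser (Earley by default), and parser.parse(text) raises lark.exceptions.UnexpectedInput unless the text is derivable from the grammar. Since the grammar declares no %ignore rule, whitespace between terminals is rejected, which is why the test strips spaces before comparing. A standalone sketch against the same grammar (the query string is one assumed-valid derivation):

from lark import Lark

sql_grammar = """
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""

parser = Lark(sql_grammar)
# Derivable from the grammar (no spaces, since none are allowed):
parser.parse("SELECTcol_1fromtable_1wherecol_1=1")
# A spaced variant like "SELECT col_1 from ..." raises UnexpectedInput.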
@@ -661,7 +608,8 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
-                                          guided_decoding_backend: str):
+                                          guided_decoding_backend: str,
+                                          sample_json_schema, sample_regex):
     with pytest.raises(openai.BadRequestError):
         _ = await client.completions.create(
             model=MODEL_NAME,
@@ -673,7 +621,8 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
         _ = await client.completions.create(
             model=MODEL_NAME,
             prompt="Give an example string that fits this regex",
-            extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
+            extra_body=dict(guided_regex=sample_regex,
+                            guided_json=sample_json_schema))

 @pytest.mark.asyncio

View File

@@ -10,59 +10,17 @@ from vllm.model_executor.guided_decoding import (
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)

-TEST_SCHEMA = {
-    "type": "object",
-    "properties": {
-        "name": {
-            "type": "string"
-        },
-        "age": {
-            "type": "integer"
-        },
-        "skills": {
-            "type": "array",
-            "items": {
-                "type": "string",
-                "maxLength": 10
-            },
-            "minItems": 3
-        },
-        "work history": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "company": {
-                        "type": "string"
-                    },
-                    "duration": {
-                        "type": "string"
-                    },
-                    "position": {
-                        "type": "string"
-                    }
-                },
-                "required": ["company", "position"]
-            }
-        }
-    },
-    "required": ["name", "age", "skills", "work history"]
-}
-
-TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
-              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
-
-def test_guided_logits_processors():
+def test_guided_logits_processors(sample_regex, sample_json_schema):
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
-    regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
-    json_LP = JSONLogitsProcessor(TEST_SCHEMA,
+    regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
+    json_LP = JSONLogitsProcessor(sample_json_schema,
                                   tokenizer,
                                   whitespace_pattern=None)
     token_ids = tokenizer.encode(
-        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
+        f"Give an example IPv4 address with this regex: {sample_regex}")
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     regex_LP(token_ids, tensor)
@@ -70,7 +28,8 @@ def test_guided_logits_processors():
     assert not torch.allclose(tensor, original_tensor)

     token_ids = tokenizer.encode(
-        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
+        f"Give an employee profile that fits this schema: {sample_json_schema}"
+    )
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     json_LP(token_ids, tensor)
@@ -80,13 +39,14 @@
 @pytest.mark.asyncio
 @pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
-async def test_guided_logits_processor_black_box(backend: str):
+async def test_guided_logits_processor_black_box(backend: str, sample_regex,
+                                                 sample_json_schema):
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
     token_ids = tokenizer.encode(
-        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
+        f"Give an example IPv4 address with this regex: {sample_regex}")
     regex_request = CompletionRequest(model='test',
                                       prompt=token_ids,
-                                      guided_regex=TEST_REGEX)
+                                      guided_regex=sample_regex)
     regex_lp = await get_guided_decoding_logits_processor(
         backend, regex_request, tokenizer)
     assert regex_lp is not None
@@ -97,10 +57,11 @@ async def test_guided_logits_processor_black_box(backend: str):
     assert not torch.allclose(tensor, original_tensor)

     token_ids = tokenizer.encode(
-        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
+        f"Give an employee profile that fits this schema: {sample_json_schema}"
+    )
     json_request = CompletionRequest(model='test',
                                      prompt=token_ids,
-                                     guided_json=TEST_SCHEMA)
+                                     guided_json=sample_json_schema)
     json_lp = await get_guided_decoding_logits_processor(
         backend, json_request, tokenizer)
     assert json_lp is not None
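
Both tests in this file only assert that applying a processor perturbs the logits tensor. The underlying contract is that a guided logits processor is a callable over the generated token ids and the next-token logits that suppresses tokens which would violate the regex or schema. A rough standalone sketch of that masking idea with an invented allow-list (this is not vLLM's actual implementation):

import torch


def mask_disallowed(logits, allowed_ids):
    # Set every logit outside the allow-list to -inf so that softmax
    # assigns those tokens zero probability.
    masked = torch.full_like(logits, float("-inf"))
    masked[allowed_ids] = logits[allowed_ids]
    return masked


logits = torch.rand(32000)  # same vocabulary size the tests use
masked = mask_disallowed(logits, allowed_ids=[42, 1337])
assert torch.isinf(masked).sum().item() == 32000 - 2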