[CI/Build] Update pixtral tests to use JSON (#8436)
parent 3f79bc3d1a
commit 8427550488
pyproject.toml
@@ -76,7 +76,7 @@ exclude = [
 
 [tool.codespell]
 ignore-words-list = "dout, te, indicies, subtile"
-skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
+skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
 
 [tool.isort]
 use_parentheses = true
tests/models/fixtures/pixtral_chat.json (new file, +1 line)
File diff suppressed because one or more lines are too long.

tests/models/fixtures/pixtral_chat_engine.json (new file, +1 line)
File diff suppressed because one or more lines are too long.
@@ -2,9 +2,10 @@
 
 Run `pytest tests/models/test_mistral.py`.
 """
-import pickle
+import json
 import uuid
-from typing import Any, Dict, List
+from dataclasses import asdict
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 from mistral_common.protocol.instruct.messages import ImageURLChunk
@@ -14,6 +15,7 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
 
 from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
 from vllm.multimodal import MultiModalDataBuiltins
+from vllm.sequence import Logprob, SampleLogprobs
 
 from .utils import check_logprobs_close
 
@@ -81,13 +83,33 @@ SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
 LIMIT_MM_PER_PROMPT = dict(image=4)
 
 MAX_MODEL_LEN = [8192, 65536]
-FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
-FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
 
+OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
 
-def load_logprobs(filename: str) -> Any:
-    with open(filename, 'rb') as f:
-        return pickle.load(f)
+
+# For the test author to store golden output in JSON
+def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
+    json_data = [(tokens, text,
+                  [{k: asdict(v)
+                    for k, v in token_logprobs.items()}
+                   for token_logprobs in (logprobs or [])])
+                 for tokens, text, logprobs in outputs]
+
+    with open(filename, "w") as f:
+        json.dump(json_data, f)
+
+
+def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
+    with open(filename, "rb") as f:
+        json_data = json.load(f)
+
+    return [(tokens, text,
+             [{int(k): Logprob(**v)
+               for k, v in token_logprobs.items()}
+              for token_logprobs in logprobs])
+            for tokens, text, logprobs in json_data]
 
 
 @pytest.mark.skip(
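The JSON fixtures store each output as (token_ids, text, per_step_logprobs): the Logprob dataclasses are flattened with asdict on dump and rebuilt with Logprob(**v) on load, and because JSON forces dict keys to strings, the loader converts them back with int(k). Below is a minimal, self-contained round-trip sketch of that scheme; the Logprob class here is a simplified stand-in (an assumption), not the real vllm.sequence.Logprob.

import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Logprob:
    # Simplified stand-in for vllm.sequence.Logprob (assumption).
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


OutputsLogprobs = List[Tuple[List[int], str, Optional[List[Dict[int, Logprob]]]]]

# One fake output: generated token ids, decoded text, and one
# {token_id: Logprob} dict per sampled step.
outputs: OutputsLogprobs = [
    ([1, 42], "hello",
     [{42: Logprob(logprob=-0.1, rank=1, decoded_token="hello")}]),
]

# Dump (mirrors _dump_outputs_w_logprobs): dataclasses become plain dicts.
json_data = [(tokens, text,
              [{k: asdict(v) for k, v in step.items()}
               for step in (logprobs or [])])
             for tokens, text, logprobs in outputs]
blob = json.dumps(json_data)

# Load (mirrors load_outputs_w_logprobs): string keys back to int token ids,
# plain dicts back to Logprob instances.
restored = [(tokens, text,
             [{int(k): Logprob(**v) for k, v in step.items()}
              for step in logprobs])
            for tokens, text, logprobs in json.loads(blob)]

assert restored[0][2][0][42].logprob == -0.1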
@@ -103,7 +125,7 @@ def test_chat(
     model: str,
     dtype: str,
 ) -> None:
-    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
     with vllm_runner(
             model,
             dtype=dtype,
@@ -120,10 +142,10 @@ def test_chat(
             outputs.extend(output)
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")
 
 
 @pytest.mark.skip(
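Both tests now pass the golden fixture as outputs_0_lst and the freshly generated run as outputs_1_lst, with name_0/name_1 relabelled to match ("h100_ref" vs. "output"). As a rough illustration of what a logprob-closeness check of this shape typically does, here is a hedged sketch under the assumption that each entry is (token_ids, text, per_step_logprobs); it is not vLLM's actual check_logprobs_close implementation.

from typing import Dict, List, Optional, Sequence, Tuple

TokensTextLogprobs = Tuple[List[int], str, Optional[List[Dict[int, object]]]]


def check_logprobs_close_sketch(outputs_0_lst: Sequence[TokensTextLogprobs],
                                outputs_1_lst: Sequence[TokensTextLogprobs],
                                name_0: str,
                                name_1: str) -> None:
    # Hedged sketch: tolerate the first token mismatch only if each run's
    # token appears among the other run's reported top-logprob candidates.
    for (ids_0, _, lp_0), (ids_1, _, lp_1) in zip(outputs_0_lst, outputs_1_lst):
        for step, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue
            assert lp_0 is not None and lp_1 is not None
            assert tok_1 in lp_0[step], (
                f"{name_1} token {tok_1} not among {name_0} candidates")
            assert tok_0 in lp_1[step], (
                f"{name_0} token {tok_0} not among {name_1} candidates")
            break  # after one tolerated divergence, stop comparing this pair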
@@ -133,7 +155,7 @@ def test_chat(
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
-    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
     args = EngineArgs(
         model=model,
         tokenizer_mode="mistral",
@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
             break
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")
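Since the fixtures are now plain JSON, refreshing them after an intentional model or runtime change reduces to re-running the reference model and re-dumping its logprobs with the new helper. A hypothetical refresh helper is sketched below; the import path assumes the edited module is tests/models/test_pixtral.py, and new_outputs is assumed to be collected the same way test_chat collects its outputs.

# Hypothetical fixture-refresh helper (an assumption, not part of this commit).
from tests.models.test_pixtral import (FIXTURE_LOGPROBS_CHAT, OutputsLogprobs,
                                       _dump_outputs_w_logprobs,
                                       load_outputs_w_logprobs)


def refresh_chat_fixture(new_outputs: OutputsLogprobs) -> None:
    # Overwrite the golden JSON, then read it back to confirm the file
    # round-trips through load_outputs_w_logprobs without errors.
    _dump_outputs_w_logprobs(new_outputs, FIXTURE_LOGPROBS_CHAT)
    load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)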