[CI/Build] Update pixtral tests to use JSON (#8436)

Cyrus Leung 2024-09-13 11:47:52 +08:00 committed by GitHub
parent 3f79bc3d1a
commit 8427550488
6 changed files with 42 additions and 18 deletions

pyproject.toml

@@ -76,7 +76,7 @@ exclude = [
 [tool.codespell]
 ignore-words-list = "dout, te, indicies, subtile"
-skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
+skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
 
 [tool.isort]
 use_parentheses = true

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long
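The two suppressed diffs are presumably the new JSON fixtures themselves (tests/models/fixtures/pixtral_chat.json and pixtral_chat_engine.json), each written out as one very long line. For orientation, here is a sketch of the layout produced by the _dump_outputs_w_logprobs helper introduced below, with invented values and with Logprob field names assumed rather than taken from the commit:

# Sketch of one fixture entry as _dump_outputs_w_logprobs would write it
# (values invented; Logprob field names are an assumption).
example_fixture = [
    [                                      # one generation
        [1, 5, 9],                         # generated token ids
        "The image shows ...",             # generated text
        [                                  # one dict per generated token
            {"1": {"logprob": -0.12, "rank": 1, "decoded_token": "The"}},
            {"5": {"logprob": -0.34, "rank": 1, "decoded_token": " image"}},
            {"9": {"logprob": -0.56, "rank": 1, "decoded_token": " shows"}},
        ],
    ],
]

Token ids end up as string keys because JSON objects only allow string keys; load_outputs_w_logprobs restores them with int(k).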

tests/models/test_pixtral.py

@@ -2,9 +2,10 @@
 Run `pytest tests/models/test_mistral.py`.
 """
-import pickle
+import json
 import uuid
-from typing import Any, Dict, List
+from dataclasses import asdict
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 from mistral_common.protocol.instruct.messages import ImageURLChunk
@@ -14,6 +15,7 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
 from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
 from vllm.multimodal import MultiModalDataBuiltins
+from vllm.sequence import Logprob, SampleLogprobs
 
 from .utils import check_logprobs_close
@@ -81,13 +83,33 @@ SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
 LIMIT_MM_PER_PROMPT = dict(image=4)
 
 MAX_MODEL_LEN = [8192, 65536]
-FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
-FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
+
+OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
 
 
-def load_logprobs(filename: str) -> Any:
-    with open(filename, 'rb') as f:
-        return pickle.load(f)
+# For the test author to store golden output in JSON
+def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
+    json_data = [(tokens, text,
+                  [{k: asdict(v)
+                    for k, v in token_logprobs.items()}
+                   for token_logprobs in (logprobs or [])])
+                 for tokens, text, logprobs in outputs]
+
+    with open(filename, "w") as f:
+        json.dump(json_data, f)
+
+
+def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
+    with open(filename, "rb") as f:
+        json_data = json.load(f)
+
+    return [(tokens, text,
+             [{int(k): Logprob(**v)
+               for k, v in token_logprobs.items()}
+              for token_logprobs in logprobs])
+            for tokens, text, logprobs in json_data]
 
 
 @pytest.mark.skip(
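The new helpers replace pickle with a JSON round-trip: asdict() flattens each Logprob dataclass on the way out, and Logprob(**v) together with int(k) rebuilds the per-token mappings on the way back in; unlike pickle, loading the fixture cannot execute arbitrary code. A minimal, self-contained sketch of that round-trip, using a stand-in dataclass instead of the real vllm.sequence.Logprob (field names assumed):

# Stand-in for vllm.sequence.Logprob; the real field names may differ.
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class Logprob:
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# One token position: mapping from candidate token id to its Logprob entry.
token_logprobs = {1734: Logprob(-0.25, rank=1, decoded_token="Hello")}

# Dump: dataclasses need asdict(), and json stringifies the int keys.
text = json.dumps({k: asdict(v) for k, v in token_logprobs.items()})

# Load: restore int keys and rebuild the dataclass instances.
restored = {int(k): Logprob(**v) for k, v in json.loads(text).items()}
assert restored == token_logprobs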
@@ -103,7 +125,7 @@ def test_chat(
     model: str,
     dtype: str,
 ) -> None:
-    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
     with vllm_runner(
             model,
             dtype=dtype,
@@ -120,10 +142,10 @@ def test_chat(
             outputs.extend(output)
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")
 
 
 @pytest.mark.skip(
@@ -133,7 +155,7 @@ def test_chat(
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
-    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
     args = EngineArgs(
         model=model,
         tokenizer_mode="mistral",
@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
         break
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")
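Taken together, the intended workflow appears to be: run the model once on the reference H100 machine, dump its outputs with _dump_outputs_w_logprobs, and let every later CI run load that JSON back as the expected ("h100_ref") side of check_logprobs_close. A hypothetical sketch using the names defined in the test module above (the generated value is fabricated, and the Logprob field names are assumed):

# Hypothetical fixture-regeneration flow; not part of this commit.
# Assumes the helpers and constants from tests/models/test_pixtral.py are in scope.
from vllm.sequence import Logprob

generated = [([1], "The",
              [{1: Logprob(logprob=-0.12, rank=1, decoded_token="The")}])]

# Done once by the test author on the reference machine:
_dump_outputs_w_logprobs(generated, FIXTURE_LOGPROBS_CHAT)

# Done by every subsequent test run:
expected = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
check_logprobs_close(outputs_0_lst=expected,
                     outputs_1_lst=generated,
                     name_0="h100_ref",
                     name_1="output")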