[Model] Support multiple images for qwen-vl (#8247)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Alex Brooks 2024-09-12 11:10:54 -06:00 committed by GitHub
parent e56bf27741
commit c6202daeed
4 changed files with 343 additions and 65 deletions

View File

@@ -254,7 +254,7 @@ Multimodal Language Models
     -
   * - :code:`QWenLMHeadModel`
     - Qwen-VL
-    - Image\ :sup:`E`
+    - Image\ :sup:`E+`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
   * - :code:`Qwen2VLForConditionalGeneration`
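The table change above flips Qwen-VL's image marker from E to E+; in the docs legend, E indicates that pre-computed image embeddings are accepted and + that multiple items may be passed per prompt (legend wording paraphrased here, not part of this diff). Below is a minimal offline-inference sketch of the new multi-image path, kept to APIs that appear in the example script changed later in this commit; the URLs and the per-prompt image limit are illustrative:

# Sketch only (not part of this commit): multi-image generation with
# Qwen-VL-Chat through vLLM's offline API. URLs are placeholders.
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

llm = LLM(
    model="Qwen/Qwen-VL-Chat",
    trust_remote_code=True,
    # vLLM defaults to one image per prompt; raise the limit explicitly.
    limit_mm_per_prompt={"image": 2},
)

image_urls = [
    "https://example.com/first.jpg",   # placeholder URL
    "https://example.com/second.jpg",  # placeholder URL
]
prompt = ("Picture 1: <img></img>\nPicture 2: <img></img>\n"
          "Can you compare these two images?\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [fetch_image(url) for url in image_urls]},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)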

View File

@@ -19,7 +19,39 @@ IMAGE_URLS = [
]


def load_qwenvl_chat(question: str, image_urls: List[str]):
model_name = "Qwen/Qwen-VL-Chat"
llm = LLM(
model=model_name,
trust_remote_code=True,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "".join(f"Picture {i}: <img></img>\n"
for i, _ in enumerate(image_urls, start=1))
# This model does not have a chat_template attribute on its tokenizer,
# so we need to explicitly pass it. We use ChatML since it's used in the
# generation utils of the model:
# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
# Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True,
chat_template=chat_template)
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids, None, chat_template
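For reference, with two images and an illustrative question, the prompt returned by this loader renders to roughly the following ChatML string (derived from the placeholder string and chat template above; the exact text comes from the tokenizer):

# Approximate rendering of `prompt` for two images and the hypothetical
# question "Can you compare these images?" (derived from the template above):
expected_prompt = (
    "<|im_start|>user\n"
    "Picture 1: <img></img>\n"
    "Picture 2: <img></img>\n"
    "\n"
    "Can you compare these images?<|im_end|>\n"
    "<|im_start|>assistant\n"
)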
-def load_phi3v(question, image_urls: List[str]):
+def load_phi3v(question: str, image_urls: List[str]):
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
@@ -30,10 +62,10 @@ def load_phi3v(question, image_urls: List[str]):
                           for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None
-    return llm, prompt, stop_token_ids, None
+    return llm, prompt, stop_token_ids, None, None


-def load_internvl(question, image_urls: List[str]):
+def load_internvl(question: str, image_urls: List[str]):
    model_name = "OpenGVLab/InternVL2-2B"
    llm = LLM(
@@ -61,7 +93,7 @@ def load_internvl(question, image_urls: List[str]):
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-    return llm, prompt, stop_token_ids, None
+    return llm, prompt, stop_token_ids, None, None


def load_qwen2_vl(question, image_urls: List[str]):
@@ -111,18 +143,19 @@ def load_qwen2_vl(question, image_urls: List[str]):
    else:
        image_data, _ = process_vision_info(messages)

-    return llm, prompt, stop_token_ids, image_data
+    return llm, prompt, stop_token_ids, image_data, None


model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "qwen2_vl": load_qwen2_vl,
+    "qwen_vl_chat": load_qwenvl_chat,
}


def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids, image_data = model_example_map[model](
+    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
        question, image_urls)

    if image_data is None:
        image_data = [fetch_image(url) for url in image_urls]
@@ -146,29 +179,32 @@ def run_generate(model, question: str, image_urls: List[str]):

def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)
+    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
+        question, image_urls)
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

-    outputs = llm.chat([{
-        "role":
-        "user",
-        "content": [
-            {
-                "type": "text",
-                "text": question,
-            },
-            *({
-                "type": "image_url",
-                "image_url": {
-                    "url": image_url
-                },
-            } for image_url in image_urls),
-        ],
-    }],
-                       sampling_params=sampling_params)
+    outputs = llm.chat(
+        [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": question,
+                },
+                *({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    },
+                } for image_url in image_urls),
+            ],
+        }],
+        sampling_params=sampling_params,
+        chat_template=chat_template,
+    )

    for o in outputs:
        generated_text = o.outputs[0].text

View File

@@ -1,11 +1,17 @@
import pathlib
-from typing import List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type, Union

import pytest
+import torch
+from PIL.Image import Image

-from vllm.multimodal.utils import rescale_image_size
+from vllm.config import ModelConfig
+from vllm.inputs import InputContext, LLMInputs
+from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size

-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput,
+                        VllmRunner, _ImageAssets)
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm
@@ -23,19 +29,205 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "Picture 1: <img></img>\nWhat is the season?: ",
})

-HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: <img></img>\nPicture 2: <img></img>\nCan you compare these images?\n"  # noqa: E501
+HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: <img></img>\nPicture 2: <img></img>\nDescribe the two images in detail.\n"  # noqa: E501
### Multimodal preprocessing tests
SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
# These values are specific to Qwen-VL/Chat; we can get these from the model
# config also, but they are hardcoded here to keep the parameterize/fixtures
# easy to read.
IMG_START_ID = 151857
IMG_END_ID = 151858
IMG_PAD_ID = 151859
TOKS_PER_IMG = 256
VIS_ENC_DIM = 4096
IMG_SIZE = 448
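The constants above restate what the input-processor tests below assert: each image placeholder expands to one start token, a fixed block of 256 pad tokens, and one end token. A small helper (illustrative only, reusing the test's own constants) that captures those expected counts:

# Illustrative helper, not part of the test suite: the per-image token counts
# that test_input_processor_valid_mm_data asserts on.
def expected_image_token_counts(num_images: int) -> Dict[int, int]:
    return {
        IMG_START_ID: num_images,
        IMG_END_ID: num_images,
        IMG_PAD_ID: num_images * TOKS_PER_IMG,
    }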
def build_model_context(model_name: str,
tokenizer_name: Optional[str] = None,
trust_remote_code: bool = False):
"""Creates an InputContext for a given model.
Args:
model_name: Name of the model being considered.
tokenizer_name: Name of the tokenizer being considered.
trust_remote_code: Whether or not to allow loading remote code.
Returns:
InputContext for the model being considered.
"""
if tokenizer_name is None:
tokenizer_name = model_name
model_config = ModelConfig(
model_name,
tokenizer_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
dtype="float32",
seed=0,
)
return InputContext(model_config)
@pytest.fixture()
def input_mapper_for_qwen():
# Lazy import to avoid initializing CUDA during test collection
from vllm.model_executor.models.qwen import input_mapper_for_qwen
return input_mapper_for_qwen
@pytest.fixture()
def input_processor_for_qwen():
# Lazy import to avoid initializing CUDA during test collection
from vllm.model_executor.models.qwen import input_processor_for_qwen
return input_processor_for_qwen
@pytest.fixture()
def qwen_vl_context() -> InputContext:
"""Get an InputContext for Qwen-VL."""
return build_model_context(model_name="Qwen/Qwen-VL",
trust_remote_code=True)
# Happy path tests for single/multi-image scenarios for the multimodal
# input processor and mapper, respectively
@pytest.mark.parametrize("num_images", [1, 2])
def test_input_processor_valid_mm_data(input_processor_for_qwen,
qwen_vl_context: InputContext,
num_images: int):
"""Happy cases for image inputs to Qwen's multimodal input processor."""
prompt = "".join(
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
inputs = LLMInputs(
prompt=prompt,
# When processing multimodal data for a multimodal model, the qwen
# input processor will overwrite the provided prompt_token_ids with
# the image prompts
prompt_token_ids=None,
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
)
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
assert isinstance(proc_inputs, dict)
# Each image should have one start / stop and a fixed context of 256
proc_tokens = proc_inputs["prompt_token_ids"]
assert proc_tokens.count(IMG_START_ID) == num_images
assert proc_tokens.count(IMG_END_ID) == num_images
assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
@pytest.mark.parametrize(
"img_data,expected_shape",
[
# single / multi-image
(SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
(2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
# single / multi-image embeddings
(torch.rand(
(TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
(torch.rand(
(1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
(torch.rand(
(2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
])
def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
qwen_vl_context: InputContext,
img_data: Union[torch.Tensor, List[Image],
Image],
expected_shape: List[int]):
"""Happy cases for image inputs to Qwen's multimodal input mapper."""
mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
# Ensure that we get the appropriately shaped pixel_values
# for images and image embeddings, respectively.
assert isinstance(mapped_img_data, MultiModalInputs)
assert "pixel_values" in mapped_img_data
assert mapped_img_data["pixel_values"].shape == expected_shape
# Sad path tests for the multimodal input processor and mapper, respectively
@pytest.mark.parametrize("mm_data", [
{
"image": torch.rand((5))
},
{
"image": torch.rand((5, 5, 5, 5, 5))
},
])
def test_input_processor_invalid_mm_data(input_processor_for_qwen,
qwen_vl_context: InputContext,
mm_data: Dict[str, torch.Tensor]):
"""Test sad cases validated in Qwen's multimodal input processor."""
tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
trust_remote_code=True)
prompt = "Picture 1: <img></img>\n"
prompt_token_ids = tokenizer.encode(prompt)
inputs = LLMInputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data)
# Should fail since we have too many or too few dimensions for embeddings
with pytest.raises(ValueError):
input_processor_for_qwen(qwen_vl_context, inputs)
@pytest.mark.parametrize(
"img_data",
[
# Wrong context length
torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
# Wrong visual encoder output size
torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
])
def test_input_mapper_invalid_mm_data(
input_mapper_for_qwen,
qwen_vl_context: InputContext,
img_data: Union[torch.Tensor, List[Image], Image],
):
"""Sad cases validated in Qwen VL's multimodal input mapper."""
with pytest.raises(ValueError):
input_mapper_for_qwen(qwen_vl_context, img_data)
### End-to-end generation tests
def get_prompt_with_path(tmp_path: pathlib.PosixPath, prompt: str,
assets: Union[_ImageAssets, List[ImageAsset]]) -> str:
"""Given a temporary dir path, export one or more image assets into the
tempdir & replace its contents with the local path to the string so that
the HF version of Qwen-VL can resolve the path and load the image ni its
forward() call.
Args:
tmp_path: Tempdir for test under consideration.
prompt: Prompt with image placeholders.
assets: List of image assets whose len equals the num placeholders.
"""
# Ensure that the number of placeholders matches the number of assets;
# If this is not true, the test is probably written incorrectly.
assert prompt.count("<img></img>") == len(assets)
# Replace the placeholders with local paths to the exported assets
for asset in assets:
image_tmp_path = tmp_path / f"{asset.name}.jpg"
asset.pil_image.save(image_tmp_path)
prompt = prompt.replace(
"<img></img>",
f"<img>{image_tmp_path}</img>",
1,
)
return prompt
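As a usage note, the helper rewrites one placeholder per asset, left to right; for the two standard test images the result looks roughly like this (the tmp paths and asset names are assumed here, not taken from this diff):

# Illustrative only: assumed tmp_path and asset names.
prompt = get_prompt_with_path(tmp_path, HF_MULTIIMAGE_IMAGE_PROMPT, image_assets)
# Before: "Picture 1: <img></img>\nPicture 2: <img></img>\n..."
# After:  "Picture 1: <img>/tmp/.../stop_sign.jpg</img>\n"
#         "Picture 2: <img>/tmp/.../cherry_blossom.jpg</img>\n..."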
### Tests for multimodal Qwen models
def run_test(
-    tmp_path: pathlib.PosixPath,
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
-    image_assets: _ImageAssets,
+    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
-    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
+    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
@@ -48,23 +240,6 @@ def run_test(
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
-    images = [asset.pil_image for asset in image_assets]
-
-    # Export the images to a tempdir and substitute it into the hf prompt;
-    # the contents between <img>/</img> will be ignored by VLLM, but the
-    # transformers implementation for the visual transformer parses this to
-    # reload it in the forward call; the contents are treated as a URL or a
-    # local path.
-    for idx, asset in enumerate(image_assets):
-        image_tmp_path = tmp_path / f"{asset.name}.jpg"
-        asset.pil_image.save(image_tmp_path)
-        HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
-            "<img></img>", f"<img>{image_tmp_path}</img>")
-
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
@@ -72,11 +247,12 @@ def run_test(
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
-    # Qwen encodes images into a fixed content size of 256
+    # Qwen encodes each image into a fixed content size of 256
    with vllm_runner(model,
-                     max_model_len=300,
+                     max_model_len=1024,
                     max_num_seqs=1,
                     dtype=dtype,
+                     limit_mm_per_prompt={"image": mm_limit},
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
@@ -85,7 +261,7 @@ def run_test(
                                              max_tokens,
                                              num_logprobs=num_logprobs,
                                              images=images)
-        for prompts, images in inputs_per_image
+        for prompts, images in inputs
    ]

    with hf_runner(model, dtype=dtype) as hf_model:
@@ -94,7 +270,7 @@ def run_test(
                                          max_tokens,
                                          num_logprobs=num_logprobs,
                                          images=images)
-        for prompts, images in inputs_per_image
+        for prompts, images in inputs
    ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -125,19 +301,81 @@ def run_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
-def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
-                           model, size_factors, dtype, max_tokens,
-                           num_logprobs) -> None:
-    run_test(
-        tmp_path,
-        hf_runner,
-        vllm_runner,
-        image_assets,
-        model,
-        size_factors=size_factors,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath,
                                        hf_runner: Type[HfRunner],
                                        vllm_runner: Type[VllmRunner],
                                        image_assets: _ImageAssets, model: str,
                                        size_factors: List[float], dtype: str,
                                        max_tokens: int,
                                        num_logprobs: int) -> None:
    """Tests multimodal models with single image prompts."""
    images = [asset.pil_image for asset in image_assets]

    prompts = [
        get_prompt_with_path(tmp_path, prompt, [asset])
        for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets)
    ]

    inputs = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, prompts)]

    run_test(
        hf_runner,
        vllm_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )
@pytest.mark.parametrize("model", multimodal_models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multimodal_models_multi_image(tmp_path: pathlib.PosixPath,
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets, model: str,
size_factors: List[float], dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
"""Tests multimodal models with multi-image prompts."""
images = [asset.pil_image for asset in image_assets]
# Put all of the images into one prompt.
prompt = get_prompt_with_path(tmp_path, HF_MULTIIMAGE_IMAGE_PROMPT,
image_assets)
inputs = [([prompt for _ in size_factors],
[[rescale_image_size(image, factor) for image in images]
for factor in size_factors])]
run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=2,
        tensor_parallel_size=1,
    )
@@ -150,7 +388,7 @@ def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("num_logprobs", [5])
def test_text_only_qwen_model_can_be_loaded_and_run(
    vllm_runner: Type[VllmRunner],
-    example_prompts,
+    example_prompts: List[str],
    model: str,
    *,
    dtype: str,

View File

@@ -47,6 +47,7 @@ from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
+from vllm.utils import is_list_of

from .utils import flatten_bn, is_pp_missing_parameter, make_layers
@@ -684,9 +685,12 @@ def input_processor_for_qwen(ctx: InputContext,
            raise ValueError(
                f"Expected img embeds to be have 3 dimensions, got {num_dims}")
        num_images = 1 if num_dims == 2 else image_data.shape[0]
-    else:
-        # TODO - handle multiple image inputs once the API is solidified
+    elif isinstance(image_data, Image.Image):
        num_images = 1
+    elif is_list_of(image_data, Image.Image):
+        num_images = len(image_data)
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")

    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
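With this dispatch, the processor can count images for each form that arrives under multi_modal_data["image"]: a single PIL image, a list of PIL images, or pre-computed embeddings. A rough sketch of the three accepted forms (file names are placeholders; the 256/4096 embedding shape mirrors the constants used in the new tests above):

# Sketch only: the three image_data forms the processor now handles.
from PIL import Image
import torch

single_image = Image.open("img_1.jpg")            # num_images == 1
image_list = [Image.open("img_1.jpg"),            # num_images == len(list)
              Image.open("img_2.jpg")]
image_embeds = torch.rand(2, 256, 4096)           # num_images == shape[0]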
@@ -767,11 +771,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
                f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
                f"received shape [{data.shape}]")
        pixel_values = data
    else:
        transform = build_normalization_transform(image_size)
-        # TODO - handle multiple image inputs once the API is solidified
-        transformed_images = [transform(data)]
+        if not isinstance(data, (list, tuple)):
+            data = [data]
+        transformed_images = [transform(datum) for datum in data]
        pixel_values = torch.stack(transformed_images, dim=0)

    return MultiModalInputs({"pixel_values": pixel_values})
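On the mapper side, pixel values are now stacked across however many images were provided. A short summary of the output shapes the new mapper tests expect (448x448 is the Qwen-VL input resolution used in those tests):

# Restated from the new mapper tests above, for quick reference:
#   one PIL image              -> pixel_values of shape (1, 3, 448, 448)
#   a list of two PIL images   -> pixel_values of shape (2, 3, 448, 448)
#   embeddings (2, 256, 4096)  -> passed through with shape (2, 256, 4096)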