diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py
index 6007a897..d032f3be 100644
--- a/tests/models/test_internvl.py
+++ b/tests/models/test_internvl.py
@@ -1,5 +1,5 @@
 import types
-from typing import List, Optional, Type
+from typing import List, Optional, Tuple, Type
 
 import pytest
 import torch
@@ -178,6 +178,74 @@ def run_test(
         )
 
 
+def run_awq_test(
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    models: Tuple[str, str],
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    source_model, quant_model = models
+
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(source_model,
+                     max_model_len=4096,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        source_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    with vllm_runner(quant_model,
+                     quantization="awq",
+                     max_model_len=4096,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        quant_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    for source_outputs, quant_outputs in zip(source_outputs_per_image,
+                                             quant_outputs_per_image):
+        # TODO: Check whether using original CLIPVisionModel can improve
+        # consistency against HF
+        check_logprobs_close(
+            outputs_0_lst=source_outputs,
+            outputs_1_lst=quant_outputs,
+            name_0="source",
+            name_1="awq",
+        )
+
+
 target_dtype = "half"
 if is_cpu():
     target_dtype = "bfloat16"
@@ -214,3 +282,36 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
             num_logprobs=num_logprobs,
             tensor_parallel_size=1,
         )
+
+
+@pytest.mark.parametrize(
+    "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+@torch.inference_mode()
+def test_awq_models(vllm_runner, image_assets, models, size_factors,
+                    dtype: str, max_tokens: int, num_logprobs: int) -> None:
+    run_awq_test(
+        vllm_runner,
+        image_assets,
+        models,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
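
For reference, `run_awq_test` pairs each prompt with copies of its image rescaled by every entry in `size_factors`, runs the source and AWQ checkpoints through the same vLLM runner, and compares greedy logprobs via `check_logprobs_close`. Below is a minimal sketch of that input construction; `rescale_image_size` here is a hypothetical stand-in for the helper the real test imports, and the prompt/image values are illustrative only.

# Illustrative sketch of how run_awq_test builds `inputs_per_image`.
from PIL import Image

def rescale_image_size(image: Image.Image, factor: float) -> Image.Image:
    # Assumed behavior of the real helper: scale both dimensions by `factor`.
    return image.resize((int(image.width * factor), int(image.height * factor)))

prompts = ["<image>\nDescribe the picture."]   # illustrative prompt
images = [Image.new("RGB", (448, 448))]        # illustrative image
size_factors = [0.25, 0.5, 1.0]

inputs_per_image = [(
    [prompt for _ in size_factors],
    [rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, prompts)]

# Each entry holds len(size_factors) prompt copies plus the rescaled images;
# size_factors == [] yields empty batches and exercises the no-image path.
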
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 3824ed35..50a50d98 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -570,7 +570,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # for the packing.
             if isinstance(param, PackedvLLMParameter
                           ) and param.packed_dim == param.output_dim:
-                param.adjust_shard_indexes_for_packing(
+                shard_size, shard_offset = \
+                    param.adjust_shard_indexes_for_packing(
                     shard_size=shard_size, shard_offset=shard_offset)
 
             loaded_weight_shard = loaded_weight.narrow(param.output_dim,
@@ -719,7 +720,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             # for the packing.
             if isinstance(param, PackedvLLMParameter
                           ) and param.packed_dim == param.output_dim:
-                param.adjust_shard_indexes_for_packing(
+                shard_size, shard_offset = \
+                    param.adjust_shard_indexes_for_packing(
                     shard_size=shard_size, shard_offset=shard_offset)
 
             loaded_weight_shard = loaded_weight.narrow(param.output_dim,
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index ae5a4a4d..06664577 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -131,6 +131,10 @@ def get_quant_config(model_config: ModelConfig,
     # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
+    # some vision model may keep quantization_config in their text_config
+    hf_text_config = getattr(model_config.hf_config, "text_config", None)
+    if hf_quant_config is None and hf_text_config is not None:
+        hf_quant_config = getattr(hf_text_config, "quantization_config", None)
     if hf_quant_config is None:
         # compressed-tensors uses a compressions_config
         hf_quant_config = getattr(model_config.hf_config, "compression_config",
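
The weight_utils.py change above covers checkpoints whose quantization section lives under the text sub-config rather than at the top level of the HF config. A minimal sketch of the lookup order, using SimpleNamespace as a stand-in for the HF config object; the attribute names mirror the diff, while the nested dict contents are illustrative and not copied from a real checkpoint.

# Minimal sketch of the fallback added to get_quant_config above.
from types import SimpleNamespace

def find_quant_config(hf_config):
    quant = getattr(hf_config, "quantization_config", None)
    text_config = getattr(hf_config, "text_config", None)
    if quant is None and text_config is not None:
        # Fall back to the text sub-config before giving up.
        quant = getattr(text_config, "quantization_config", None)
    return quant

cfg = SimpleNamespace(text_config=SimpleNamespace(
    quantization_config={"quant_method": "awq", "bits": 4, "group_size": 128}))
assert find_quant_config(cfg)["quant_method"] == "awq"
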
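
The linear.py hunks further up fix a related detail: previously the result of `adjust_shard_indexes_for_packing` was never assigned, so the shard size and offset were not rescaled for packed (for example 4-bit AWQ) parameters before narrowing. A rough sketch of the idea, assuming a plain pack factor and no tiling; the function name and numbers below are illustrative, not the actual PackedvLLMParameter API.

# With 4-bit weights packed 8-per-int32 along the output dim, shard sizes and
# offsets expressed in logical columns must be divided by the pack factor
# before slicing the packed tensor.
def adjust_shard_indexes_for_packing(shard_size: int, shard_offset: int,
                                     pack_factor: int = 8):
    return shard_size // pack_factor, shard_offset // pack_factor

shard_size, shard_offset = adjust_shard_indexes_for_packing(shard_size=4096,
                                                            shard_offset=8192)
# 4096 int4 columns occupy 512 packed int32 columns starting at column 1024.
assert (shard_size, shard_offset) == (512, 1024)
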
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 887a353d..499cdb43 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -87,6 +87,7 @@ class InternLM2Attention(nn.Module):
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
+        self.key_value_groups = int(self.num_heads / self.num_kv_heads)
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -120,6 +121,14 @@ class InternLM2Attention(nn.Module):
                               cache_config=cache_config,
                               quant_config=quant_config)
 
+    def split_qkv(self, qkv: torch.Tensor):
+        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
+        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
+        q = q.reshape(-1, self.q_size)
+        k = k.reshape(-1, self.kv_size)
+        v = v.reshape(-1, self.kv_size)
+        return q, k, v
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -128,7 +137,7 @@ class InternLM2Attention(nn.Module):
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.wqkv(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k, v = self.split_qkv(qkv)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.wo(attn_output)
@@ -324,24 +333,6 @@ class InternLM2ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                if "wqkv" in name:
-                    config = self.config
-                    kv_groups = (config.num_attention_heads //
-                                 config.num_key_value_heads)
-                    head_dim = config.hidden_size // config.num_attention_heads
-                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
-                                                       head_dim,
-                                                       loaded_weight.shape[-1])
-                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
-                                             dim=1)
-                    wq = wq.reshape(-1, wq.shape[-1])
-                    wk = wk.reshape(-1, wk.shape[-1])
-                    wv = wv.reshape(-1, wv.shape[-1])
-                    weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
-                else:
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
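
The fused wqkv projection stores, per KV head, a group of query heads followed by one key and one value head; the new `split_qkv` separates them at runtime, which is why the wqkv special case in `load_weights` could be dropped and packed AWQ weights can now be loaded as-is. A small standalone sketch of that interleaving, with made-up sizes (the diff hardcodes the head size as 128, which matches InternLM2 checkpoints):

# Standalone sketch of the [q * key_value_groups, k, v] layout that
# split_qkv untangles. Sizes are illustrative, not the real model's.
import torch

num_heads, num_kv_heads, head_dim = 4, 2, 8
key_value_groups = num_heads // num_kv_heads
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim

num_tokens = 3
qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_dim)

# Group the fused projection per KV head, then peel off q, k and v.
grouped = qkv.view(-1, num_kv_heads, key_value_groups + 2, head_dim)
q, k, v = torch.split(grouped, [key_value_groups, 1, 1], dim=2)
q = q.reshape(-1, q_size)
k = k.reshape(-1, kv_size)
v = v.reshape(-1, kv_size)

assert q.shape == (num_tokens, q_size)
assert k.shape == v.shape == (num_tokens, kv_size)
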