From ee93f4f92acbd9759a9af80747bc2a4459f07639 Mon Sep 17 00:00:00 2001 From: Qubitium-ModelCloud Date: Wed, 3 Jul 2024 06:25:17 +0800 Subject: [PATCH] [CORE] Quantized lm-head Framework (#4442) Co-authored-by: Robert Shaw Co-authored-by: ZX --- tests/lora/test_layers.py | 10 +-- tests/quantization/test_lm_head.py | 45 ++++++++++++ tests/spec_decode/e2e/test_mlp_correctness.py | 2 +- tests/test_logits_processor.py | 2 +- vllm/lora/layers.py | 4 +- .../model_executor/layers/logits_processor.py | 16 +++-- .../layers/quantization/base_config.py | 9 +++ .../layers/quantization/gptq.py | 13 +++- .../layers/quantization/gptq_marlin.py | 15 ++-- .../layers/quantization/marlin.py | 13 +++- .../layers/vocab_parallel_embedding.py | 70 ++++++++++++++----- vllm/model_executor/models/arctic.py | 3 +- vllm/model_executor/models/baichuan.py | 6 +- vllm/model_executor/models/bloom.py | 4 +- vllm/model_executor/models/chatglm.py | 7 +- vllm/model_executor/models/commandr.py | 8 +-- vllm/model_executor/models/dbrx.py | 3 +- vllm/model_executor/models/deepseek.py | 6 +- vllm/model_executor/models/deepseek_v2.py | 6 +- vllm/model_executor/models/falcon.py | 6 +- vllm/model_executor/models/gemma.py | 4 +- vllm/model_executor/models/gemma2.py | 4 +- vllm/model_executor/models/gpt2.py | 4 +- vllm/model_executor/models/gpt_bigcode.py | 4 +- vllm/model_executor/models/gpt_j.py | 3 +- vllm/model_executor/models/gpt_neox.py | 3 +- vllm/model_executor/models/internlm2.py | 6 +- vllm/model_executor/models/jais.py | 4 +- vllm/model_executor/models/llama.py | 3 +- vllm/model_executor/models/llava.py | 5 +- vllm/model_executor/models/llava_next.py | 5 +- vllm/model_executor/models/minicpm.py | 7 +- vllm/model_executor/models/mixtral.py | 3 +- vllm/model_executor/models/mixtral_quant.py | 6 +- vllm/model_executor/models/mlp_speculator.py | 8 +-- vllm/model_executor/models/mpt.py | 4 +- vllm/model_executor/models/olmo.py | 6 +- vllm/model_executor/models/opt.py | 4 +- vllm/model_executor/models/orion.py | 6 +- vllm/model_executor/models/phi.py | 5 +- vllm/model_executor/models/phi3_small.py | 3 +- vllm/model_executor/models/phi3v.py | 6 +- vllm/model_executor/models/qwen.py | 6 +- vllm/model_executor/models/qwen2.py | 8 +-- vllm/model_executor/models/qwen2_moe.py | 6 +- vllm/model_executor/models/stablelm.py | 6 +- vllm/model_executor/models/starcoder2.py | 6 +- vllm/model_executor/models/xverse.py | 6 +- 48 files changed, 268 insertions(+), 121 deletions(-) create mode 100644 tests/quantization/test_lm_head.py diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 2e51e95a..7207af6b 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -475,10 +475,10 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), - embedding=linear.weight, + lm_head=linear, embedding_bias=None) - original_weight = linear.weight.clone() + original_lm_head = deepcopy(linear) linear.weight[logits_processor. 
org_vocab_size:logits_processor.org_vocab_size + @@ -490,7 +490,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = logits_processor._get_logits(hidden_states=input_, - embedding=linear.weight, + lm_head=linear, embedding_bias=None) result[:, vocab_size + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling @@ -519,11 +519,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), - embedding=original_weight, + lm_head=original_lm_head, embedding_bias=None)[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), - embedding=original_weight, + lm_head=original_lm_head, embedding_bias=None) rtol, atol = TOLERANCES[lora_result.dtype] diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py new file mode 100644 index 00000000..dd9a0168 --- /dev/null +++ b/tests/quantization/test_lm_head.py @@ -0,0 +1,45 @@ +"""Tests whether gptq models with quantized lm_head can be loaded. + +Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`. +""" +from typing import Tuple + +import pytest +import torch + +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod + +PROMPT = "On the surface of Mars, we found" + +MODELS_QUANT = [( + "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse", + True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), + ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)] + + +@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT) +def test_lm_head( + vllm_runner, + model_lm_head_quant: Tuple[str, bool], +) -> None: + model, lm_head_quantized = model_lm_head_quant + vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048) + + lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model.lm_head) + + if lm_head_quantized: + assert isinstance( + lm_head_layer.linear_method, + (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod)) + else: + assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod) + + print( + vllm_model.generate_greedy(prompts=["Hello my name is"], + max_tokens=10)[0][1]) + del vllm_model diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 9a9f2acb..dd67a773 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -34,7 +34,7 @@ SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator" MAX_SPEC_TOKENS = 5 # precision -PRECISION = "float16" +PRECISION = "float32" @pytest.mark.parametrize( diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 4ee98050..8ee2d781 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str): device=device, pin_memory=is_pin_memory_available()) logits_processor_output = logits_processor( - embedding=None, + lm_head=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 2fddfcca..0a63f9ef 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1172,11 +1172,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): def _get_logits( self, hidden_states: torch.Tensor, - embedding: torch.Tensor, + lm_head: VocabParallelEmbedding, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) + logits = lm_head.linear_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias logits = tensor_model_parallel_gather(logits) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 8062bfb5..f6fcf49e 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn from vllm.distributed import tensor_model_parallel_gather +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -40,7 +42,7 @@ class LogitsProcessor(nn.Module): def forward( self, - embedding: torch.Tensor, + lm_head: VocabParallelEmbedding, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, @@ -52,8 +54,7 @@ class LogitsProcessor(nn.Module): sampling_metadata) # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, embedding, embedding_bias) - + logits = self._get_logits(hidden_states, lm_head, embedding_bias) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap @@ -68,12 +69,13 @@ class LogitsProcessor(nn.Module): return logits - def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + def _get_logits(self, hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: # Get the logits for the next tokens. 
- logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias + logits = lm_head.linear_method.apply(lm_head, + hidden_states, + bias=embedding_bias) logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index c23b6616..1607470c 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -87,6 +87,15 @@ class QuantizationConfig(ABC): raise ValueError(f"Cannot find any of {keys} in the model's " "quantization config.") + @staticmethod + def get_from_keys_or(config: Dict[str, Any], keys: List[str], + default: Any) -> Any: + """Get a optional value from the model's quantization config.""" + try: + return QuantizationConfig.get_from_keys(config, keys) + except ValueError: + return default + @abstractmethod def get_quant_method( self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index ae9f7019..595d6ab9 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -10,6 +10,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.utils import set_weight_attrs @@ -24,10 +25,12 @@ class GPTQConfig(QuantizationConfig): weight_bits: int, group_size: int, desc_act: bool, + lm_head_quantized: bool, ) -> None: self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized self.pack_factor = Fraction(32, self.weight_bits) if self.weight_bits not in [2, 3, 4, 8]: raise ValueError( @@ -37,7 +40,8 @@ class GPTQConfig(QuantizationConfig): def __repr__(self) -> str: return (f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " - f"desc_act={self.desc_act})") + f"desc_act={self.desc_act})," + f"lm_head_quantized={self.lm_head_quantized}") @classmethod def get_name(cls) -> str: @@ -61,11 +65,14 @@ class GPTQConfig(QuantizationConfig): weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) - return cls(weight_bits, group_size, desc_act) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, lm_head_quantized) def get_quant_method( self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: - if isinstance(layer, LinearBase): + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return GPTQLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index c6e9279c..97aae33f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( 
QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.utils import get_device_capability_stateless logger = init_logger(__name__) @@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool) -> None: + is_sym: bool, lm_head_quantized: bool) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -69,6 +70,7 @@ class GPTQMarlinConfig(QuantizationConfig): self.group_size = group_size self.desc_act = desc_act self.is_sym = is_sym + self.lm_head_quantized = lm_head_quantized # Verify if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: @@ -96,7 +98,8 @@ class GPTQMarlinConfig(QuantizationConfig): def __repr__(self) -> str: return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " - f"desc_act={self.desc_act})") + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized})") @classmethod def get_name(cls) -> str: @@ -120,7 +123,10 @@ class GPTQMarlinConfig(QuantizationConfig): group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) is_sym = cls.get_from_keys(config, ["sym"]) - return cls(weight_bits, group_size, desc_act, is_sym) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, + lm_head_quantized) @classmethod def override_quantization_method(cls, hf_quant_cfg, @@ -145,7 +151,8 @@ class GPTQMarlinConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: - if isinstance(layer, LinearBase): + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return GPTQMarlinLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 3613c9d9..f0a9cf55 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig): def __init__( self, group_size: int, + lm_head_quantized: bool, ) -> None: # Group size for the quantization. 
self.group_size = group_size + self.lm_head_quantized = lm_head_quantized if self.group_size != 128 and self.group_size != -1: raise ValueError( "Currently, only group size 128 and -1 (channelwise) " @@ -51,7 +54,8 @@ class MarlinConfig(QuantizationConfig): self.perm_len = 1024 def __repr__(self) -> str: - return f"MarlinConfig(group_size={self.group_size})" + return (f"MarlinConfig(group_size={self.group_size}, " + f"lm_head_quantized={self.lm_head_quantized})") @classmethod def get_name(cls) -> str: @@ -73,7 +77,9 @@ class MarlinConfig(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) - return cls(group_size) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(group_size, lm_head_quantized) @classmethod def override_quantization_method(cls, hf_quant_cfg, @@ -96,7 +102,8 @@ class MarlinConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: - if isinstance(layer, LinearBase): + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return MarlinLinearMethod(self) return None diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 4650b2c2..d70eb1c2 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -157,6 +160,7 @@ class VocabParallelEmbedding(torch.nn.Module): params_dtype: type of the parameters. org_num_embeddings: original vocabulary size (without LoRA). padding_size: padding size for the vocabulary. + quant_config: quant config for the layer """ # noqa: E501 def __init__(self, @@ -164,7 +168,8 @@ class VocabParallelEmbedding(torch.nn.Module): embedding_dim: int, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None): super().__init__() # Keep the input dimensions. @@ -187,6 +192,14 @@ class VocabParallelEmbedding(torch.nn.Module): self.org_vocab_size, tp_rank, self.tp_size) self.embedding_dim = embedding_dim + + linear_method = None + if quant_config is not None: + linear_method = quant_config.get_quant_method(self) + if linear_method is None: + linear_method = UnquantizedLinearMethod() + self.linear_method: QuantizeMethodBase = linear_method + if params_dtype is None: params_dtype = torch.get_default_dtype() # Divide the weight matrix along the vocaburaly dimension. 
@@ -201,14 +214,14 @@ class VocabParallelEmbedding(torch.nn.Module): self.num_added_embeddings_per_partition = ( self.shard_indices.added_vocab_end_index - self.shard_indices.added_vocab_start_index) - self.weight = Parameter( - torch.empty(self.num_embeddings_per_partition, - self.embedding_dim, - dtype=params_dtype)) - set_weight_attrs(self.weight, { - "parallel_dim": 0, - "weight_loader": self.weight_loader - }) + + self.linear_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) @classmethod def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, @@ -288,10 +301,32 @@ class VocabParallelEmbedding(torch.nn.Module): return ret def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - parallel_dim = param.parallel_dim - assert loaded_weight.shape[parallel_dim] == self.org_vocab_size - loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index: - self.shard_indices.org_vocab_end_index] + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + # If parameter does not have output dim, then it should + # be copied onto all gpus (e.g. g_idx for act_order gptq). + if output_dim is None: + assert param.data.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + return + + # Shard indexes for loading the weight + start_idx = self.shard_indices.org_vocab_start_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. + if packed_dim is not None and packed_dim == output_dim: + assert loaded_weight.shape[output_dim] == (self.org_vocab_size // + param.pack_factor) + start_idx = start_idx // param.pack_factor + shard_size = shard_size // param.pack_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size + + # Copy the data. 
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param[:loaded_weight.shape[0]].data.copy_(loaded_weight) param[loaded_weight.shape[0]:].data.fill_(0) @@ -346,16 +381,17 @@ class ParallelLMHead(VocabParallelEmbedding): bias: bool = False, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None): super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size) + org_num_embeddings, padding_size, quant_config) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)) set_weight_attrs(self.bias, { - "parallel_dim": 0, - "weight_loader": self.weight_loader + "output_dim": 0, + "weight_loader": self.weight_loader, }) else: self.register_parameter("bias", None) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index fec52e01..49e57a84 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -412,6 +412,7 @@ class ArcticForCausalLM(nn.Module): self.lm_head = ParallelLMHead( self.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok @@ -434,7 +435,7 @@ class ArcticForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index ddc4e908..e1ea8bfc 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -328,7 +328,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): self.quant_config = quant_config self.model = BaiChuanModel(config, position_embedding, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -346,7 +348,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 8387c8e3..86ae32e0 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -276,7 +276,7 @@ class BloomForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.transformer = BloomModel(config, cache_config, quant_config) - self.lm_head_weight = self.transformer.word_embeddings.weight + self.lm_head = self.transformer.word_embeddings self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -294,7 +294,7 @@ class BloomForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, 
+ logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e6012a6d..553ddf90 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -303,7 +303,8 @@ class ChatGLMModel(nn.Module): self.encoder = GLMTransformer(config, cache_config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, - config.hidden_size) + config.hidden_size, + quant_config=quant_config) def forward( self, @@ -355,7 +356,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) self.transformer = ChatGLMModel(config, cache_config, quant_config) - self.lm_head_weight = self.transformer.output_layer.weight + self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() @@ -373,7 +374,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 2961f421..5f6e3a13 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -363,12 +363,12 @@ class CohereForCausalLM(nn.Module): sampling_metadata: SamplingMetadata) -> torch.Tensor: is_not_lora = hasattr(self.model.embed_tokens, 'weight') if is_not_lora: - embedding_weights = self.model.embed_tokens.weight + logits = self.logits_processor(self.model.embed_tokens, + hidden_states, sampling_metadata) else: - embedding_weights = self.model.embed_tokens.base_layer.weight + logits = self.logits_processor(self.model.embed_tokens.base_layer, + hidden_states, sampling_metadata) - logits = self.logits_processor(embedding_weights, hidden_states, - sampling_metadata) return logits def sample( diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 210cf616..d758333b 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -370,6 +370,7 @@ class DbrxForCausalLM(nn.Module): config.d_model, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -389,7 +390,7 @@ class DbrxForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index e9ceca9b..3fd6f221 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -377,7 +377,9 @@ class DeepseekForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = DeepseekModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = 
LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -395,7 +397,7 @@ class DeepseekForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 3cf62afd..fb4097fd 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -465,7 +465,9 @@ class DeepseekV2ForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = DeepseekV2Model(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -483,7 +485,7 @@ class DeepseekV2ForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 89b0bbf0..93f07327 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -394,13 +394,13 @@ class FalconForCausalLM(nn.Module): if config.tie_word_embeddings is not None else True) if self.tie_word_embeddings: - self.lm_head_weight = self.transformer.word_embeddings.weight + self.lm_head = self.transformer.word_embeddings else: self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) - self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -422,7 +422,7 @@ class FalconForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 0a5a7ed3..b603a591 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -347,8 +347,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.model.embed_tokens.weight, - hidden_states, sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) return logits def sample( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 1f921c8b..8fedff62 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -346,8 +346,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.model.embed_tokens.weight, - hidden_states, sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, 
hidden_states, + sampling_metadata) return logits def sample( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 55f2e274..be19f4ba 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -238,7 +238,7 @@ class GPT2LMHeadModel(nn.Module): self.config = config self.quant_config = quant_config self.transformer = GPT2Model(config, cache_config, quant_config) - self.lm_head_weight = self.transformer.wte.weight + self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -256,7 +256,7 @@ class GPT2LMHeadModel(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 7d0bf39c..cc42413d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -259,7 +259,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): self.quant_config = quant_config self.transformer = GPTBigCodeModel(config, cache_config, quant_config, lora_config) - self.lm_head_weight = self.transformer.wte.weight + self.lm_head = self.transformer.wte self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -281,7 +281,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index de7f86af..4bb9debe 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -229,6 +229,7 @@ class GPTJForCausalLM(nn.Module): config.vocab_size, config.n_embd, bias=True, + quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -247,7 +248,7 @@ class GPTJForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata, self.lm_head.bias) return logits diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 3658b8fb..b306574b 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -241,6 +241,7 @@ class GPTNeoXForCausalLM(nn.Module): self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -259,7 +260,7 @@ class GPTNeoXForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.embed_out.weight, hidden_states, + logits = self.logits_processor(self.embed_out, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/internlm2.py 
b/vllm/model_executor/models/internlm2.py index 283bc064..22132f40 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -253,7 +253,9 @@ class InternLM2ForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = InternLM2Model(config, cache_config, quant_config) - self.output = ParallelLMHead(config.vocab_size, config.hidden_size) + self.output = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -271,7 +273,7 @@ class InternLM2ForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.output.weight, hidden_states, + logits = self.logits_processor(self.output, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 2758e2d0..0030c761 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -273,7 +273,7 @@ class JAISLMHeadModel(nn.Module): self.config = config self.quant_config = quant_config self.transformer = JAISModel(config, cache_config, quant_config) - self.lm_head_weight = self.transformer.wte.weight + self.lm_head = self.transformer.wte if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: @@ -297,7 +297,7 @@ class JAISLMHeadModel(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index af75b6be..77edcd74 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -380,6 +380,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight @@ -403,7 +404,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 39c47ddd..bbec4dbd 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -125,7 +125,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size) + org_num_embeddings=self.language_model.org_vocab_size, + quant_config=quant_config) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) @@ -255,7 +256,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = 
self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 8b078391..f67598c4 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -186,7 +186,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size) + org_num_embeddings=self.language_model.org_vocab_size, + quant_config=quant_config) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) @@ -438,7 +439,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 33020432..4ccf1cf0 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -449,6 +449,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, ) self.scale_width = self.config.hidden_size / self.config.dim_model_base @@ -472,10 +473,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): sampling_metadata: SamplingMetadata) -> torch.Tensor: hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: - lm_head_weight = self.model.embed_tokens.weight + lm_head = self.model.embed_tokens else: - lm_head_weight = self.lm_head.weight - logits = self.logits_processor(lm_head_weight, hidden_states, + lm_head = self.lm_head + logits = self.logits_processor(lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 5144e7ea..7f5e3b96 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -331,6 +331,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -350,7 +351,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index dde2da20..10faa5cc 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -344,7 +344,9 @@ class MixtralForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = MixtralModel(config, cache_config, quant_config) - self.lm_head = 
ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -362,7 +364,7 @@ class MixtralForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 290a703a..97f7ec74 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -8,7 +8,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import MLPSpeculatorConfig @@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module): self.proj = nn.ModuleList([proj_first] + [proj_tied] * (self.max_speculative_tokens - 1)) - head = nn.Linear(self.inner_dim, self.vocab_size, bias=False) + head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) self.head = nn.ModuleList([head] * self.max_speculative_tokens) ln = MLPSpeculatorLayerNorm(self.inner_dim, @@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module): # TODO: not yet supporting top_k_tokens_per_head previous_hidden_states = states - logits = self.logits_processor(self.head[head_index].weight, - states, sampling_metadata) + logits = self.logits_processor(self.head[head_index], states, + sampling_metadata) output = self.sampler(logits.flatten(0, 1), sampling_metadata) last_tokens = output.sampled_token_ids diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 28dc5922..7d658b39 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -263,7 +263,7 @@ class MPTForCausalLM(nn.Module): self.quant_config = quant_config self.transformer = MPTModel(config, cache_config, quant_config) - self.lm_head_weight = self.transformer.wte.weight + self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -281,7 +281,7 @@ class MPTForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 53215f32..408c0c88 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -283,15 +283,15 @@ class OlmoForCausalLM(nn.Module): self.config = config self.model = OlmoModel(config, cache_config, quant_config) if config.tie_word_embeddings: - self.lm_head_weight = self.model.embed_tokens.weight + self.lm_head = self.model.embed_tokens else: self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, 
org_num_embeddings=config.vocab_size, + quant_config=quant_config, ) - self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -313,7 +313,7 @@ class OlmoForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d12a51af..edc16710 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -294,7 +294,7 @@ class OPTForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = OPTModel(config, cache_config, quant_config) - self.lm_head_weight = self.model.decoder.embed_tokens.weight + self.lm_head = self.model.decoder.embed_tokens self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -312,7 +312,7 @@ class OPTForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index a298f030..8159cc13 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -259,7 +259,9 @@ class OrionForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = OrionModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -277,7 +279,7 @@ class OrionForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index cc8e31fe..ac7496f6 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -268,7 +268,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - bias=True) + bias=True, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -287,7 +288,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata, self.lm_head.bias) return logits diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 706ae652..cc06929f 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -366,6 +366,7 @@ class Phi3SmallForCausalLM(nn.Module): config.hidden_size, org_num_embeddings=config.vocab_size, 
padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -400,7 +401,7 @@ class Phi3SmallForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) if self.dummy_token_indices is not None and logits is not None: logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index eff4e502..d73a4202 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -365,7 +365,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): self.model = LlamaModel(config, cache_config, quant_config) self.vision_embed_tokens = Phi3HDImageEmbedding( vlm_config, config, self.model.embed_tokens) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -409,7 +411,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 408c206c..47c85c78 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -235,7 +235,9 @@ class QWenLMHeadModel(nn.Module): self.config = config self.quant_config = quant_config self.transformer = QWenModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -253,7 +255,7 @@ class QWenLMHeadModel(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3691a3d2..e9ae2192 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -316,11 +316,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): self.model = Qwen2Model(config, cache_config, quant_config) if config.tie_word_embeddings: - self.lm_head_weight = self.model.embed_tokens.weight + self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size) - self.lm_head_weight = self.lm_head.weight + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -339,7 +339,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = 
self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 8decb446..ccaa6f20 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -362,7 +362,9 @@ class Qwen2MoeForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = Qwen2MoeModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -380,7 +382,7 @@ class Qwen2MoeForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 1098b303..5451b56e 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -240,7 +240,9 @@ class StablelmForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = StableLMEpochModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -258,7 +260,7 @@ class StablelmForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 6f3d5d51..1752bfd4 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -242,7 +242,7 @@ class Starcoder2ForCausalLM(nn.Module): self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: - self.lm_head_weight = self.model.embed_tokens.weight + self.lm_head = self.model.embed_tokens else: self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( @@ -250,8 +250,8 @@ class Starcoder2ForCausalLM(nn.Module): config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, ) - self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -270,7 +270,7 @@ class Starcoder2ForCausalLM(nn.Module): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 08d3efd3..84f0ffc3 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -310,7 +310,9 @@ class 
XverseForCausalLM(nn.Module, SupportsLoRA): self.quant_config = quant_config self.model = XverseModel(config, cache_config, quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -328,7 +330,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits
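
Notes on the interface change (editor's sketches, not part of the patch):

The central change above is that LogitsProcessor and LogitsProcessorWithLoRA now receive the ParallelLMHead (or tied VocabParallelEmbedding) module itself rather than its weight tensor, and route the final projection through the module's linear_method so a GPTQ/Marlin kernel can run the lm_head matmul. The following is a minimal, self-contained illustration of that dispatch pattern; UnquantizedMethod and FakeLMHead are stand-ins invented for the example, not vLLM classes.

import torch


class UnquantizedMethod:
    """Stand-in for vLLM's UnquantizedLinearMethod: a plain dense matmul."""

    def apply(self, layer, x, bias=None):
        out = torch.matmul(x, layer.weight.t())
        if bias is not None:
            out = out + bias
        return out


class FakeLMHead(torch.nn.Module):
    """Stand-in for ParallelLMHead: owns a weight and an attached linear_method."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))
        self.linear_method = UnquantizedMethod()


def get_logits(hidden_states, lm_head, embedding_bias=None):
    # Old flow: torch.matmul(hidden_states, embedding.t()) + embedding_bias
    # New flow: defer to whatever method the quant config attached to the layer.
    return lm_head.linear_method.apply(lm_head, hidden_states, bias=embedding_bias)


if __name__ == "__main__":
    head = FakeLMHead(vocab_size=32, hidden_size=8)
    print(get_logits(torch.randn(2, 8), head).shape)  # torch.Size([2, 32])

This is why every model file in the patch switches from self.logits_processor(self.lm_head.weight, ...) to self.logits_processor(self.lm_head, ...) and passes quant_config=quant_config into ParallelLMHead: the layer, not the caller, now knows how its projection is executed.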
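On the config side, GPTQConfig, GPTQMarlinConfig, and MarlinConfig read an optional boolean "lm_head" key from the checkpoint's quantization config via the new get_from_keys_or helper and, when it is true, let get_quant_method return the quantized linear method for ParallelLMHead layers as well as LinearBase layers. Below is a small sketch of that lookup-with-default behaviour, assuming a HuggingFace-style quantization_config dict; the standalone functions mirror the patch's helpers but are illustrative only.

from typing import Any, Dict, List


def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
    """Return the first matching key's value, mirroring QuantizationConfig.get_from_keys."""
    for key in keys:
        if key in config:
            return config[key]
    raise ValueError(f"Cannot find any of {keys} in the quantization config.")


def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
    """Same lookup, but fall back to a default instead of raising."""
    try:
        return get_from_keys(config, keys)
    except ValueError:
        return default


hf_quant_cfg = {"bits": 4, "group_size": 128, "desc_act": False, "lm_head": True}
print(get_from_keys_or(hf_quant_cfg, ["lm_head"], default=False))  # True
print(get_from_keys_or({"bits": 4}, ["lm_head"], default=False))   # False

Checkpoints that do not declare "lm_head" therefore keep the previous behaviour (an unquantized lm_head), which is what the new tests/quantization/test_lm_head.py asserts for the non-quantized models in MODELS_QUANT.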
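The rewritten VocabParallelEmbedding.weight_loader also has to shard packed quantized tensors: when a loaded tensor is packed along the same dimension the vocab is sharded on, the shard offset and size are divided by the pack factor before the narrow. The function below is an illustrative restatement of that arithmetic under the patch's attribute names (output_dim, packed_dim, pack_factor); it is not vLLM source, and the toy tensors stand in for a checkpoint's full and packed vocab-dimension weights.

from typing import Optional

import torch


def shard_vocab_weight(loaded_weight: torch.Tensor,
                       org_vocab_size: int,
                       start_idx: int,
                       shard_size: int,
                       output_dim: int = 0,
                       packed_dim: Optional[int] = None,
                       pack_factor: int = 1) -> torch.Tensor:
    if packed_dim is not None and packed_dim == output_dim:
        # Packed along the sharded dim, e.g. an int32 tensor holding
        # pack_factor quantized values per element: rescale the shard window.
        assert loaded_weight.shape[output_dim] == org_vocab_size // pack_factor
        start_idx //= pack_factor
        shard_size //= pack_factor
    else:
        assert loaded_weight.shape[output_dim] == org_vocab_size
    return loaded_weight.narrow(output_dim, start_idx, shard_size)


full = torch.arange(64).reshape(64, 1)    # unpacked checkpoint tensor, vocab=64
print(shard_vocab_weight(full, 64, 32, 32).shape)                      # (32, 1)
packed = torch.arange(8).reshape(8, 1)    # same vocab packed 8x per element
print(shard_vocab_weight(packed, 64, 32, 32,
                         packed_dim=0, pack_factor=8).shape)           # (4, 1)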