diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 82c8dfa7..e7580f21 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -14,6 +14,7 @@ from vllm.model_executor.weight_utils import (get_quant_config,
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
     "AquilaModel": AquilaForCausalLM,
+    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
     "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
     "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py
index 33280d9a..3fece0e2 100644
--- a/vllm/model_executor/models/aquila.py
+++ b/vllm/model_executor/models/aquila.py
@@ -147,6 +147,7 @@ class AquilaAttention(nn.Module):
             rotary_dim=self.head_dim,
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
+            num_kv_heads=self.num_kv_heads,
         )
 
     def forward(
@@ -177,7 +178,7 @@ class AquilaDecoderLayer(nn.Module):
         self.self_attn = AquilaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
         )
@@ -308,7 +309,7 @@ class AquilaForCausalLM(nn.Module):
         q_proj_shard_size = (self.config.hidden_size // tp_size)
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
-                              self.config.num_attention_heads // tp_size)
+                              self.config.num_key_value_heads // tp_size)
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),
diff --git a/vllm/transformers_utils/configs/aquila.py b/vllm/transformers_utils/configs/aquila.py
index 944e8f0e..86a6f2ba 100644
--- a/vllm/transformers_utils/configs/aquila.py
+++ b/vllm/transformers_utils/configs/aquila.py
@@ -33,6 +33,7 @@ class AquilaConfig(PretrainedConfig):
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
+        num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.006,
@@ -49,6 +50,11 @@ class AquilaConfig(PretrainedConfig):
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
         self.num_attention_heads = num_attention_heads
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
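
The core of the patch is wiring num_key_value_heads (grouped-query attention) through the Aquila config, attention module, and weight-loading shard math. Below is a minimal sketch of the per-rank shard-size arithmetic changed in the last aquila.py hunk, using illustrative values that are not taken from any particular AquilaChat2 checkpoint.

# Sketch only: shows why the KV shard must scale with num_key_value_heads
# rather than num_attention_heads under tensor parallelism.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8   # GQA: fewer KV heads than query heads
tp_size = 2               # tensor-parallel world size

head_dim = hidden_size // num_attention_heads                      # 128
q_proj_shard_size = hidden_size // tp_size                         # 2048
# Before the fix this used num_attention_heads and over-sized the KV shard.
kv_proj_shard_size = head_dim * num_key_value_heads // tp_size     # 512

print(q_proj_shard_size, kv_proj_shard_size)

The backward-compatibility branch in AquilaConfig mirrors this: when num_key_value_heads is absent (older Aquila checkpoints), it falls back to num_attention_heads, which reduces the computation above to ordinary multi-head attention.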