diff --git a/vllm/config.py b/vllm/config.py index b1a3a82f..d95faf52 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -302,7 +302,11 @@ class ModelConfig: return 1 # For DBRX and MPT - if self.hf_config.model_type in ["dbrx", "mpt"]: + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads)