From a3e8a05d4c1b79dd44eb92bb6f57eb40c3fbdb21 Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Mon, 17 Jun 2024 15:26:41 -0700 Subject: [PATCH] [Bugfix] Fix KV head calculation for MPT models when using GQA (#5142) --- vllm/config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index b1a3a82f..d95faf52 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -302,7 +302,11 @@ class ModelConfig: return 1 # For DBRX and MPT - if self.hf_config.model_type in ["dbrx", "mpt"]: + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads)