[Bugfix] Fix KV head calculation for MPT models when using GQA (#5142)
This commit is contained in:
parent
e441bad674
commit
a3e8a05d4c
@@ -302,7 +302,11 @@ class ModelConfig:
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            if self.hf_config.model_type == "mpt":
                if "kv_n_heads" in self.hf_config.attn_config:
                    return self.hf_config.attn_config["kv_n_heads"]
                return self.hf_config.num_attention_heads
            if self.hf_config.model_type == "dbrx":
                return getattr(self.hf_config.attn_config, "kv_n_heads",
                               self.hf_config.num_attention_heads)
Loading…
Reference in New Issue
Block a user