[Bugfix] Fix KV head calculation for MPT models when using GQA (#5142)
This commit is contained in:
parent
e441bad674
commit
a3e8a05d4c
@@ -302,7 +302,11 @@ class ModelConfig:
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
# For DBRX and MPT
|
# For DBRX and MPT
|
||||||
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
if self.hf_config.model_type == "mpt":
|
||||||
|
if "kv_n_heads" in self.hf_config.attn_config:
|
||||||
|
return self.hf_config.attn_config["kv_n_heads"]
|
||||||
|
return self.hf_config.num_attention_heads
|
||||||
|
if self.hf_config.model_type == "dbrx":
|
||||||
return getattr(self.hf_config.attn_config, "kv_n_heads",
|
return getattr(self.hf_config.attn_config, "kv_n_heads",
|
||||||
self.hf_config.num_attention_heads)
|
self.hf_config.num_attention_heads)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user