diff --git a/vllm/config.py b/vllm/config.py index 06ade925..567bb44b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -114,8 +114,9 @@ class ModelConfig: # Note: for falcon, when new_decoder_architecture is True, the # multi_query flag is ignored and we use n_head_kv for the number of # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] new_decoder_arch_falcon = ( - self.hf_config.model_type == "falcon" + self.hf_config.model_type in falcon_model_types and getattr(self.hf_config, "new_decoder_architecture", False)) if not new_decoder_arch_falcon and getattr(self.hf_config, "multi_query", False):