From 898285c9bf306e32c7f161b75bf0bf7fd483f265 Mon Sep 17 00:00:00 2001
From: Kyujin Cho
Date: Sun, 10 Sep 2023 17:39:02 +0900
Subject: [PATCH] fix: CUDA error when inferencing with Falcon-40B base model
 (#992)

---
 vllm/config.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index 06ade925..567bb44b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -114,8 +114,9 @@ class ModelConfig:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
         new_decoder_arch_falcon = (
-            self.hf_config.model_type == "falcon"
+            self.hf_config.model_type in falcon_model_types
             and getattr(self.hf_config, "new_decoder_architecture", False))
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):
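
Context for the change above: older Falcon checkpoints report model_type as
"RefinedWeb" or "RefinedWebModel" instead of "falcon", so the previous equality
check missed them, treated them as multi-query models with a single KV head,
and later triggered a CUDA error at inference time. The following is a minimal
sketch (not vLLM's actual code) of how the patched membership test could be
used to resolve the KV-head count from a Hugging Face config; the helper name
get_num_kv_heads and the fallback attributes are illustrative assumptions.

from transformers import AutoConfig

FALCON_MODEL_TYPES = ["falcon", "RefinedWeb", "RefinedWebModel"]

def get_num_kv_heads(hf_config) -> int:
    # Mirrors the patched check: any Falcon-family model_type with
    # new_decoder_architecture=True ignores the multi_query flag.
    new_decoder_arch_falcon = (
        hf_config.model_type in FALCON_MODEL_TYPES
        and getattr(hf_config, "new_decoder_architecture", False))
    if not new_decoder_arch_falcon and getattr(hf_config, "multi_query", False):
        # Multi-query attention: all query heads share a single KV head.
        return 1
    # New-decoder-architecture Falcon checkpoints store the KV-head count
    # in n_head_kv; fall back to the full attention-head count otherwise.
    if hasattr(hf_config, "n_head_kv"):
        return hf_config.n_head_kv
    return hf_config.num_attention_heads

# Usage sketch: Falcon-40B uses the new decoder architecture, so the
# membership test now matches it whether the config says "falcon" or
# one of the older "RefinedWeb" model_type strings.
config = AutoConfig.from_pretrained("tiiuae/falcon-40b", trust_remote_code=True)
print(get_num_kv_heads(config))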