From b72af8f1eded6f5838be29eb6093ab0e0e0c240c Mon Sep 17 00:00:00 2001 From: zhaoyang-star Date: Mon, 29 Jan 2024 14:47:39 +0800 Subject: [PATCH] Fix error when tp > 1 (#2644) Co-authored-by: zhaoyang-star --- vllm/engine/llm_engine.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5b73ef08..0d836a1f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -236,7 +236,6 @@ class LLMEngine: model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) - cache_config = copy.deepcopy(self.cache_config) for rank, (worker, (node_id, _)) in enumerate(zip(self.workers, @@ -252,7 +251,7 @@ class LLMEngine: rank, distributed_init_method, lora_config=self.lora_config, - cache_config=cache_config, + kv_cache_dtype=self.cache_config.cache_dtype, )) driver_rank = 0 @@ -265,7 +264,7 @@ class LLMEngine: driver_rank, distributed_init_method, lora_config=self.lora_config, - cache_config=cache_config, + kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, )