diff --git a/vllm/config.py b/vllm/config.py
index f3e204af..0e3f4ac8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -135,7 +135,8 @@ class ModelConfig:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

-    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU worker."""
         # For GPTBigCode & Falcon:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
@@ -147,11 +148,15 @@ class ModelConfig:
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):
             # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
             return 1
         # For Falcon:
         if getattr(self.hf_config, "n_head_kv", None) is not None:
             return (self.hf_config.n_head_kv //
                     parallel_config.tensor_parallel_size)
+        if getattr(self.hf_config, "num_kv_heads", None) is not None:
+            return (self.hf_config.num_kv_heads //
+                    parallel_config.tensor_parallel_size)
         # For LLaMA-2:
         if getattr(self.hf_config, "num_key_value_heads", None) is not None:
             return (self.hf_config.num_key_value_heads //
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 3d5a723d..cdb79020 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -33,7 +33,7 @@ class CacheEngine:

         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.num_heads = model_config.get_num_heads(parallel_config)
+        self.num_heads = model_config.get_num_kv_heads(parallel_config)
         self.dtype = model_config.dtype

         self.block_size = cache_config.block_size
@@ -146,7 +146,7 @@ class CacheEngine:
         parallel_config: ParallelConfig,
     ) -> int:
         head_size = model_config.get_head_size()
-        num_heads = model_config.get_num_heads(parallel_config)
+        num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)

         key_cache_block = block_size * num_heads * head_size
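
The rename makes the intent explicit: the cache engine sizes its KV cache by the number of key/value heads per worker, not the number of query heads. Below is a minimal, standalone sketch of the same lookup order, for illustration only. The helper name num_kv_heads_per_worker, the simplified Falcon check, and the final fallback to num_attention_heads are assumptions (the fallback is not visible in the truncated hunk); the real logic lives in ModelConfig.get_num_kv_heads.

    # Sketch, not the vLLM implementation: resolves per-worker KV heads in the
    # same precedence order as the diff above.
    from types import SimpleNamespace


    def num_kv_heads_per_worker(hf_config, tensor_parallel_size: int) -> int:
        # Assumed simplification of the Falcon check: when the new decoder
        # architecture is enabled, multi_query is ignored in favor of n_head_kv.
        new_decoder_arch_falcon = (
            getattr(hf_config, "model_type", None) == "falcon"
            and getattr(hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(hf_config, "multi_query",
                                                   False):
            # Multi-query attention: a single shared KV head.
            return 1
        # Try the model-specific attribute names in the diff's order:
        # Falcon (n_head_kv), then num_kv_heads, then LLaMA-2 (num_key_value_heads).
        for attr in ("n_head_kv", "num_kv_heads", "num_key_value_heads"):
            value = getattr(hf_config, attr, None)
            if value is not None:
                return value // tensor_parallel_size
        # Assumed fallback: plain multi-head attention, KV heads == query heads.
        return hf_config.num_attention_heads // tensor_parallel_size


    # Example: a LLaMA-2-70B-style config (64 query heads, 8 KV heads) sharded
    # across 4 GPUs keeps 2 KV heads per worker.
    cfg = SimpleNamespace(num_attention_heads=64, num_key_value_heads=8)
    assert num_kv_heads_per_worker(cfg, tensor_parallel_size=4) == 2

For grouped- or multi-query models this number is smaller than the query-head count, which is why CacheEngine must use it when computing key_cache_block = block_size * num_heads * head_size; using query heads would overallocate the cache.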