From 3c635092f9de3237385073f6e1484c63ed256d6f Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Wed, 30 Oct 2024 14:04:45 +0000
Subject: [PATCH] add assert in TensorParallel for num_attention_heads and
 num_key_value_heads

---
 model.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/model.py b/model.py
index 83417ac..934e768 100644
--- a/model.py
+++ b/model.py
@@ -81,9 +81,10 @@ class Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.num_key_values = config.num_key_value_heads
         self.head_dim = self.hidden_size//self.num_heads
-        model_parallel_size = pgm.process_group_manager.tp_world_size
-        self.num_local_heads = config.num_attention_heads // model_parallel_size # TP parallelism
-        self.num_local_kv_heads = config.num_key_value_heads // model_parallel_size # TP parallelism
+        assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size"
+        assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size"
+        self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism
+        self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism
         self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
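
Note (illustration only, not part of the patch): a minimal standalone sketch of the tensor-parallel head split that these asserts guard. The numeric values below are assumed example settings, not taken from the repository's config; without the divisibility checks, the integer division would silently drop heads on some ranks.

# Standalone sketch with assumed example values (32 Q heads, 8 KV heads, TP=4).
num_attention_heads = 32
num_key_value_heads = 8
tp_world_size = 4

# Same checks as the patch: every rank must receive a whole number of heads.
assert num_attention_heads % tp_world_size == 0, "num_attention_heads should be divisible by tp world size"
assert num_key_value_heads % tp_world_size == 0, "num_key_value_heads should be divisible by tp world size"

num_local_heads = num_attention_heads // tp_world_size      # 8 query heads per TP rank
num_local_kv_heads = num_key_value_heads // tp_world_size   # 2 KV heads per TP rank
print(num_local_heads, num_local_kv_heads)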