diff --git a/model.py b/model.py
index 83417ac..934e768 100644
--- a/model.py
+++ b/model.py
@@ -81,9 +81,10 @@ class Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.num_key_values = config.num_key_value_heads
         self.head_dim = self.hidden_size//self.num_heads
-        model_parallel_size = pgm.process_group_manager.tp_world_size
-        self.num_local_heads = config.num_attention_heads // model_parallel_size # TP parallelism
-        self.num_local_kv_heads = config.num_key_value_heads // model_parallel_size # TP parallelism
+        assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size"
+        assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size"
+        self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism
+        self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism
         self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
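
For context, a minimal standalone sketch of the guard this hunk introduces, using assumed example values (32 attention heads, 8 KV heads, a tp_world_size of 4; none of these numbers come from the diff itself):

    # Illustrative only: mirrors the divisibility check and per-rank head split added above.
    num_attention_heads = 32   # assumed example value
    num_key_value_heads = 8    # assumed example value
    tp_world_size = 4          # assumed tensor-parallel world size

    assert num_attention_heads % tp_world_size == 0, "num_attention_heads should be divisible by tp world size"
    assert num_key_value_heads % tp_world_size == 0, "num_key_value_heads should be divisible by tp world size"

    num_local_heads = num_attention_heads // tp_world_size      # 8 query heads per TP rank
    num_local_kv_heads = num_key_value_heads // tp_world_size   # 2 KV heads per TP rank

Failing fast with an assert gives a clear error at model construction time, instead of a silent truncation (integer division rounding down) that would only surface later as a shape mismatch.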