add assert in TensorParallel for num_attention_heads and key_values_heads

This commit is contained in:
ferdinand.mom 2024-10-30 14:04:45 +00:00
parent 1dbe034d57
commit 3c635092f9

View File

@ -81,9 +81,10 @@ class Attention(nn.Module):
self.num_heads = config.num_attention_heads
self.num_key_values = config.num_key_value_heads
self.head_dim = self.hidden_size//self.num_heads
model_parallel_size = pgm.process_group_manager.tp_world_size
self.num_local_heads = config.num_attention_heads // model_parallel_size # TP parallelism
self.num_local_kv_heads = config.num_key_value_heads // model_parallel_size # TP parallelism
assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size"
assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size"
self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism
self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism
self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)