add asserts in TensorParallel for num_attention_heads and num_key_value_heads
commit 3c635092f9
parent 1dbe034d57
model.py (7 changed lines)
@@ -81,9 +81,10 @@ class Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.num_key_values = config.num_key_value_heads
         self.head_dim = self.hidden_size//self.num_heads
-        model_parallel_size = pgm.process_group_manager.tp_world_size
-        self.num_local_heads = config.num_attention_heads // model_parallel_size # TP parallelism
-        self.num_local_kv_heads = config.num_key_value_heads // model_parallel_size # TP parallelism
+        assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size"
+        assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size"
+        self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism
+        self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism
 
         self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
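The sketch below is a minimal, self-contained illustration of the check this commit adds: when attention heads are sharded across tensor-parallel ranks, each rank must own a whole number of heads, so both head counts have to divide evenly by the TP world size. The Config dataclass and the tp_world_size argument are illustrative stand-ins for the repository's config object and pgm.process_group_manager.tp_world_size, not the actual code from model.py.

from dataclasses import dataclass

@dataclass
class Config:
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 8

def local_head_counts(config: Config, tp_world_size: int) -> tuple[int, int]:
    # Same checks the commit introduces: fail fast with a clear message
    # instead of silently losing heads to integer division.
    assert config.num_attention_heads % tp_world_size == 0, \
        "num_attention_heads should be divisible by tp world size"
    assert config.num_key_value_heads % tp_world_size == 0, \
        "num_key_value_heads should be divisible by tp world size"
    num_local_heads = config.num_attention_heads // tp_world_size
    num_local_kv_heads = config.num_key_value_heads // tp_world_size
    return num_local_heads, num_local_kv_heads

if __name__ == "__main__":
    # 32 heads / 8 KV heads over 4 TP ranks -> (8, 2) heads per rank.
    print(local_head_counts(Config(), tp_world_size=4))
    # With tp_world_size=6 the first assert would fire, since 32 % 6 != 0.

Without the asserts, an invalid world size would pass silently: 32 // 6 gives 5 local heads per rank, leaving 2 heads unassigned and producing shape mismatches only later in the forward pass.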