now handle TP + PP meta device

ferdinand.mom 2024-12-01 20:26:40 +00:00
parent bccee5d037
commit a84a9d5942
5 changed files with 86 additions and 43 deletions

View File

@@ -10,6 +10,8 @@ import contextlib
from picotron.utils import assert_no_meta_tensors
import picotron.process_group_manager as pgm
from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel
@contextlib.contextmanager
def init_model_with_dematerialized_weights(include_buffers: bool = False):
"""
@@ -85,10 +87,11 @@ def init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_
dist.barrier()
model.load_state_dict(state_dict, strict=True, assign=True)
dist.barrier()
assert_no_meta_tensors(model)
# Initialize model parameters
initialization_manager.init_model_parameters()
dist.barrier()
return model
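# Illustrative sketch (not part of this commit) of how the two helpers above are meant
# to be combined; `get_model` is a hypothetical model factory used only for this example:
#
#   with init_model_with_dematerialized_weights():
#       model = get_model(model_config)   # parameters live on the meta device, no memory allocated
#   model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=path)
#   # -> checkpoint shards are loaded per rank and assert_no_meta_tensors(model) passes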
class InitializationManager:
@@ -115,43 +118,87 @@ class InitializationManager:
# Generate base layer names
layer_names = []
#TODO: what if only tensor parallelism is activated?
base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
if isinstance(self.model, PipelineParallel):
base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
else:
base_names = [f"model.layers.{id}" for id in range(self.model_config.num_hidden_layers)]
for layer in base_names:
layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
# Add special layers based on pipeline stage
if pgm.process_group_manager.pp_is_first_stage:
# Add special layers based on pipeline stage or non-PP case
if isinstance(self.model, PipelineParallel):
if pgm.process_group_manager.pp_is_first_stage:
layer_names.insert(0, "model.embed_tokens.weight")
elif pgm.process_group_manager.pp_is_last_stage:
layer_names.extend(["model.norm.weight", "lm_head.weight"])
else:
layer_names.insert(0, "model.embed_tokens.weight")
elif pgm.process_group_manager.pp_is_last_stage:
layer_names.extend(["model.norm.weight", "lm_head.weight"])
return layer_names
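# Example of the names produced above (illustrative; assumes 2 decoder layers on the
# first PP stage and HF Llama-style component names such as "self_attn.q_proj"):
#   ["model.embed_tokens.weight",
#    "model.layers.0.self_attn.q_proj.weight", ...,
#    "model.layers.1.mlp.down_proj.weight"]
# On the last PP stage the list instead ends with "model.norm.weight" and "lm_head.weight".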
def adjust_tensor_size(self, tensor, name):
"""Resize tensor based on architecture changes."""
if 'attention' not in name:
return tensor
"""Resize tensor based on architecture changes and tensor parallelism."""
tp_rank = pgm.process_group_manager.tp_rank
tp_size = pgm.process_group_manager.tp_world_size
hidden_size = self.model_config.hidden_size
head_dim = hidden_size // self.model_config.num_attention_heads
if 'q_proj.weight' in name:
target_dim = self.model_config.num_attention_heads * head_dim
elif 'k_proj.weight' in name or 'v_proj.weight' in name:
target_dim = self.model_config.num_key_value_heads * head_dim
else:
# Handle embedding and final projection layers
if 'embedding.weight' in name or 'final_proj.weight' in name:
vocab_size = self.model_config.vocab_size
vocab_per_rank = vocab_size // tp_size
if tensor.shape[0] != vocab_per_rank:
start_idx = tp_rank * vocab_per_rank
end_idx = start_idx + vocab_per_rank
tensor = tensor[start_idx:end_idx, :]
return tensor
# Adjust tensor size if needed
if tensor.shape[0] != target_dim:
if target_dim > tensor.shape[0]:
pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1],
dtype=tensor.dtype, device=tensor.device)
tensor = torch.cat([tensor, pad_tensor], dim=0)
# Handle attention layers
if 'attention' in name:
head_dim = hidden_size // self.model_config.num_attention_heads
if 'q_proj.weight' in name:
total_heads = self.model_config.num_attention_heads
heads_per_rank = total_heads // tp_size
target_dim = heads_per_rank * head_dim
elif 'k_proj.weight' in name or 'v_proj.weight' in name:
total_heads = self.model_config.num_key_value_heads
heads_per_rank = total_heads // tp_size
target_dim = heads_per_rank * head_dim
elif 'out_proj.weight' in name:
# For out_proj, we split along the second dimension
target_dim = tensor.shape[0] # First dimension stays the same
if tensor.shape[1] != hidden_size // tp_size:
tensor = tensor[:, (hidden_size // tp_size) * tp_rank:(hidden_size // tp_size) * (tp_rank + 1)]
return tensor
else:
tensor = tensor[:target_dim, :]
return tensor
if tensor.shape[0] != target_dim:
if target_dim > tensor.shape[0]:
pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1],
dtype=tensor.dtype, device=tensor.device)
tensor = torch.cat([tensor, pad_tensor], dim=0)
else:
tensor = tensor[:target_dim, :]
# Handle MLP layers
elif 'mlp' in name:
intermediate_size = self.model_config.intermediate_size
intermediate_size_per_rank = intermediate_size // tp_size
if 'up_proj.weight' in name or 'gate_proj.weight' in name:
if tensor.shape[0] != intermediate_size_per_rank:
start_idx = tp_rank * intermediate_size_per_rank
end_idx = start_idx + intermediate_size_per_rank
tensor = tensor[start_idx:end_idx, :]
elif 'down_proj.weight' in name:
if tensor.shape[1] != intermediate_size_per_rank:
start_idx = tp_rank * intermediate_size_per_rank
end_idx = start_idx + intermediate_size_per_rank
tensor = tensor[:, start_idx:end_idx]
return tensor
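# Worked example of the tensor-parallel resizing above (illustrative sizes, not taken
# from any real config): hidden_size=4096, num_attention_heads=32, intermediate_size=11008,
# tp_size=4, tp_rank=1:
#   head_dim                 = 4096 // 32      = 128
#   q_proj target_dim        = (32 // 4) * 128 = 1024  -> rows padded or truncated to 1024
#   out_proj columns kept    = [1024:2048]             (hidden_size // tp_size per rank)
#   up_proj / gate_proj rows = [2752:5504]             (intermediate_size // tp_size per rank)
#   down_proj columns kept   = [2752:5504]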
def convert_safetensors_to_hf_name(self, sft_name):

View File

@@ -19,6 +19,7 @@ class RingAttentionFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale, is_causal):
comm = ContextCommunicate("comm")
#TODO(fmom): add flex attention
#TODO(fmom): add flash attention
#TODO(fmom): Find a better way to save these tensors without cloning
k_og = k.clone()

View File

@@ -160,7 +160,6 @@ class Attention(nn.Module):
out = self.out_proj(out) # [batch_size, seq_length, hidden_dim]
return out
class MLP(nn.Module):
def __init__(self, config) -> None:
super().__init__()

View File

@@ -8,13 +8,9 @@ class ProcessGroupManager:
self.world_size = dist.get_world_size()
self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size))
self.tp_size = tp_size
self.cp_size = cp_size
self.pp_size = pp_size
self.dp_size = dp_size
assert self.world_size == self.tp_size * self.cp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * CP ({self.cp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"
assert self.world_size == tp_size * cp_size * pp_size * dp_size, f"World size ({self.world_size}) != TP ({tp_size}) * CP ({cp_size}) * PP ({pp_size}) * DP ({dp_size})"
self.grid = torch.arange(self.world_size).view(self.dp_size, self.pp_size, self.cp_size, self.tp_size) # DP * PP * CP * TP grid
self.grid = torch.arange(self.world_size).view(dp_size, pp_size, cp_size, tp_size) # DP * PP * CP * TP grid
# Find the position of the current process in the grid
self.dp_rank, self.pp_rank, self.cp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()
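# Worked example of the grid lookup above (illustrative sizes): with world_size=8 and
# dp=2, pp=2, cp=1, tp=2, torch.arange(8).view(2, 2, 1, 2) places global rank 5 at
# index (1, 0, 0, 1), so dp_rank=1, pp_rank=0, cp_rank=0, tp_rank=1.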
@ -36,36 +32,36 @@ class ProcessGroupManager:
self.cp_dp_group_ids = self.grid[:, self.pp_rank, :, self.tp_rank].flatten().tolist()
# Tensor parallelism
self.tp_world_size = dist.get_world_size(group=self.tp_group)
self.tp_first_rank = self.tp_group_ids[0]
self.tp_last_rank = self.tp_group_ids[-1]
self.tp_world_size = dist.get_world_size(group=self.tp_group)
# Context parallelism
self.cp_world_size = dist.get_world_size(group=self.cp_group)
self.cp_first_rank = self.cp_group_ids[0]
self.cp_last_rank = self.cp_group_ids[-1]
self.cp_world_size = dist.get_world_size(group=self.cp_group)
self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_size]
self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_size]
self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_world_size]
self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_world_size]
# Pipeline parallelism
self.pp_world_size = dist.get_world_size(group=self.pp_group)
self.pp_first_rank = self.pp_group_ids[0]
self.pp_last_rank = self.pp_group_ids[-1]
self.pp_is_first_stage = self.pp_rank == 0
self.pp_is_last_stage = self.pp_rank == self.pp_size - 1
self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.cp_rank, self.tp_rank].item())
self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1
self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.cp_rank, self.tp_rank].item())
self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.cp_rank, self.tp_rank].item())
self.pp_world_size = dist.get_world_size(group=self.pp_group)
# Data parallelism
self.dp_world_size = dist.get_world_size(group=self.dp_group)
self.dp_first_rank = self.dp_group_ids[0]
self.dp_last_rank = self.dp_group_ids[-1]
self.dp_world_size = dist.get_world_size(group=self.dp_group)
# Context + Data parallelism
self.cp_dp_world_size = dist.get_world_size(group=self.cp_dp_group)
def __str__(self):
return f"TP({self.tp_size})-CP({self.cp_size})-PP({self.pp_size})-DP({self.dp_size})-Rank({self.global_rank})"
return f"TP({self.tp_world_size})-CP({self.cp_world_size})-PP({self.pp_world_size})-DP({self.dp_world_size})-Rank({self.global_rank})"
def setup_process_group_manager(tp_size, cp_size, pp_size, dp_size):
global process_group_manager
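# Hypothetical usage (assumes torch.distributed has already been initialized):
#   setup_process_group_manager(tp_size=2, cp_size=1, pp_size=2, dp_size=2)
#   print(pgm.process_group_manager)   # e.g. "TP(2)-CP(1)-PP(2)-DP(2)-Rank(0)"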

View File

@ -162,7 +162,7 @@ if __name__ == "__main__":
model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
# TODO: load existing checkpoint here to continue pre-training
#TODO: load existing checkpoint here to continue pre-training
if pgm.process_group_manager.cp_world_size > 1:
model = apply_context_parallel(model)