From a84a9d5942634d596f44150aed3041d2337fcd02 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Sun, 1 Dec 2024 20:26:40 +0000 Subject: [PATCH] now handle TP + PP meta device --- picotron/checkpoint.py | 99 ++++++++++++++----- picotron/context_parallel/context_parallel.py | 1 + picotron/model.py | 1 - picotron/process_group_manager.py | 26 +++-- train.py | 2 +- 5 files changed, 86 insertions(+), 43 deletions(-) diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py index 9b429c8..17762f9 100644 --- a/picotron/checkpoint.py +++ b/picotron/checkpoint.py @@ -10,6 +10,8 @@ import contextlib from picotron.utils import assert_no_meta_tensors import picotron.process_group_manager as pgm +from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel + @contextlib.contextmanager def init_model_with_dematerialized_weights(include_buffers: bool = False): """ @@ -85,10 +87,11 @@ def init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_ dist.barrier() model.load_state_dict(state_dict, strict=True, assign=True) dist.barrier() - + assert_no_meta_tensors(model) + # Initialize model parameters initialization_manager.init_model_parameters() - + dist.barrier() return model class InitializationManager: @@ -115,43 +118,87 @@ class InitializationManager: # Generate base layer names layer_names = [] - #TODO: what if there is only tensor parallel that is activated ? - base_names = [f"model.layers.{id}" for id in self.model.layer_distribution] + if isinstance(self.model, PipelineParallel): + base_names = [f"model.layers.{id}" for id in self.model.layer_distribution] + else: + base_names = [f"model.layers.{id}" for id in range(self.model_config.num_hidden_layers)] + for layer in base_names: layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components) - # Add special layers based on pipeline stage - if pgm.process_group_manager.pp_is_first_stage: + # Add special layers based on pipeline stage or non-PP case + if isinstance(self.model, PipelineParallel): + if pgm.process_group_manager.pp_is_first_stage: + layer_names.insert(0, "model.embed_tokens.weight") + elif pgm.process_group_manager.pp_is_last_stage: + layer_names.extend(["model.norm.weight", "lm_head.weight"]) + else: layer_names.insert(0, "model.embed_tokens.weight") - elif pgm.process_group_manager.pp_is_last_stage: layer_names.extend(["model.norm.weight", "lm_head.weight"]) - + return layer_names def adjust_tensor_size(self, tensor, name): - """Resize tensor based on architecture changes.""" - if 'attention' not in name: - return tensor - + """Resize tensor based on architecture changes and tensor parallelism.""" + tp_rank = pgm.process_group_manager.tp_rank + tp_size = pgm.process_group_manager.tp_world_size hidden_size = self.model_config.hidden_size - head_dim = hidden_size // self.model_config.num_attention_heads - if 'q_proj.weight' in name: - target_dim = self.model_config.num_attention_heads * head_dim - elif 'k_proj.weight' in name or 'v_proj.weight' in name: - target_dim = self.model_config.num_key_value_heads * head_dim - else: + # Handle embedding and final projection layers + if 'embedding.weight' in name or 'final_proj.weight' in name: + vocab_size = self.model_config.vocab_size + vocab_per_rank = vocab_size // tp_size + if tensor.shape[0] != vocab_per_rank: + start_idx = tp_rank * vocab_per_rank + end_idx = start_idx + vocab_per_rank + tensor = tensor[start_idx:end_idx, :] return tensor - - # Adjust tensor size if needed - if tensor.shape[0] != target_dim: - if target_dim > 
tensor.shape[0]: - pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1], - dtype=tensor.dtype, device=tensor.device) - tensor = torch.cat([tensor, pad_tensor], dim=0) + + # Handle attention layers + if 'attention' in name: + head_dim = hidden_size // self.model_config.num_attention_heads + + if 'q_proj.weight' in name: + total_heads = self.model_config.num_attention_heads + heads_per_rank = total_heads // tp_size + target_dim = heads_per_rank * head_dim + elif 'k_proj.weight' in name or 'v_proj.weight' in name: + total_heads = self.model_config.num_key_value_heads + heads_per_rank = total_heads // tp_size + target_dim = heads_per_rank * head_dim + elif 'out_proj.weight' in name: + # For out_proj, we split along the second dimension + target_dim = tensor.shape[0] # First dimension stays the same + if tensor.shape[1] != hidden_size // tp_size: + tensor = tensor[:, (hidden_size // tp_size) * tp_rank:(hidden_size // tp_size) * (tp_rank + 1)] + return tensor else: - tensor = tensor[:target_dim, :] + return tensor + if tensor.shape[0] != target_dim: + if target_dim > tensor.shape[0]: + pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1], + dtype=tensor.dtype, device=tensor.device) + tensor = torch.cat([tensor, pad_tensor], dim=0) + else: + tensor = tensor[:target_dim, :] + + # Handle MLP layers + elif 'mlp' in name: + intermediate_size = self.model_config.intermediate_size + intermediate_size_per_rank = intermediate_size // tp_size + + if 'up_proj.weight' in name or 'gate_proj.weight' in name: + if tensor.shape[0] != intermediate_size_per_rank: + start_idx = tp_rank * intermediate_size_per_rank + end_idx = start_idx + intermediate_size_per_rank + tensor = tensor[start_idx:end_idx, :] + elif 'down_proj.weight' in name: + if tensor.shape[1] != intermediate_size_per_rank: + start_idx = tp_rank * intermediate_size_per_rank + end_idx = start_idx + intermediate_size_per_rank + tensor = tensor[:, start_idx:end_idx] + return tensor def convert_safetensors_to_hf_name(self, sft_name): diff --git a/picotron/context_parallel/context_parallel.py b/picotron/context_parallel/context_parallel.py index 357d907..3c16831 100644 --- a/picotron/context_parallel/context_parallel.py +++ b/picotron/context_parallel/context_parallel.py @@ -19,6 +19,7 @@ class RingAttentionFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, sm_scale, is_causal): comm = ContextCommunicate("comm") + #TODO(fmom): add flex attention #TODO(fmom): add flash attention #TODO(fmom): Find a better to save these tensors without cloning k_og = k.clone() diff --git a/picotron/model.py b/picotron/model.py index 6fabfd0..ac777f8 100644 --- a/picotron/model.py +++ b/picotron/model.py @@ -160,7 +160,6 @@ class Attention(nn.Module): out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] return out - class MLP(nn.Module): def __init__(self, config) -> None: super().__init__() diff --git a/picotron/process_group_manager.py b/picotron/process_group_manager.py index 3ec1524..17014e3 100644 --- a/picotron/process_group_manager.py +++ b/picotron/process_group_manager.py @@ -8,13 +8,9 @@ class ProcessGroupManager: self.world_size = dist.get_world_size() self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) - self.tp_size = tp_size - self.cp_size = cp_size - self.pp_size = pp_size - self.dp_size = dp_size - assert self.world_size == self.tp_size * self.cp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * CP 
({self.cp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})" + assert self.world_size == tp_size * cp_size * pp_size * dp_size, f"World size ({self.world_size}) != TP ({tp_size}) * CP ({cp_size}) * PP ({pp_size}) * DP ({dp_size})" - self.grid = torch.arange(self.world_size).view(self.dp_size, self.pp_size, self.cp_size, self.tp_size) # DP * PP * CP * TP grid + self.grid = torch.arange(self.world_size).view(dp_size, pp_size, cp_size, tp_size) # DP * PP * CP * TP grid # Find the position of the current process in the grid self.dp_rank, self.pp_rank, self.cp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() @@ -36,36 +32,36 @@ class ProcessGroupManager: self.cp_dp_group_ids = self.grid[:, self.pp_rank, :, self.tp_rank].flatten().tolist() # Tensor parallelism + self.tp_world_size = dist.get_world_size(group=self.tp_group) self.tp_first_rank = self.tp_group_ids[0] self.tp_last_rank = self.tp_group_ids[-1] - self.tp_world_size = dist.get_world_size(group=self.tp_group) # Context parallelism + self.cp_world_size = dist.get_world_size(group=self.cp_group) self.cp_first_rank = self.cp_group_ids[0] self.cp_last_rank = self.cp_group_ids[-1] - self.cp_world_size = dist.get_world_size(group=self.cp_group) - self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_size] - self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_size] + self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_world_size] + self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_world_size] # Pipeline parallelism + self.pp_world_size = dist.get_world_size(group=self.pp_group) self.pp_first_rank = self.pp_group_ids[0] self.pp_last_rank = self.pp_group_ids[-1] self.pp_is_first_stage = self.pp_rank == 0 - self.pp_is_last_stage = self.pp_rank == self.pp_size - 1 - self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.cp_rank, self.tp_rank].item()) + self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 + self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.cp_rank, self.tp_rank].item()) self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.cp_rank, self.tp_rank].item()) - self.pp_world_size = dist.get_world_size(group=self.pp_group) # Data parallelism + self.dp_world_size = dist.get_world_size(group=self.dp_group) self.dp_first_rank = self.dp_group_ids[0] self.dp_last_rank = self.dp_group_ids[-1] - self.dp_world_size = dist.get_world_size(group=self.dp_group) # Context + Data paralellism self.cp_dp_world_size = dist.get_world_size(group=self.cp_dp_group) def __str__(self): - return f"TP({self.tp_size})-CP({self.cp_size})-PP({self.pp_size})-DP({self.dp_size})-Rank({self.global_rank})" + return f"TP({self.tp_world_size})-CP({self.cp_world_size})-PP({self.pp_world_size})-DP({self.dp_world_size})-Rank({self.global_rank})" def setup_process_group_manager(tp_size, cp_size, pp_size, dp_size): global process_group_manager diff --git a/train.py b/train.py index a2a1733..f334085 100644 --- a/train.py +++ b/train.py @@ -162,7 +162,7 @@ if __name__ == "__main__": model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"]) - # TODO: load existing checkpoint here to continue pre-training + #TODO: load existing checkpoint here to continue pre-training if pgm.process_group_manager.cp_world_size > 1: model = 
apply_context_parallel(model)
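
Note on the tensor-parallel sizing rule that the new `adjust_tensor_size` enforces: embedding/final projection, q/k/v and gate/up projections end up with a dim-0 shard of size full_dim // tp_world_size, out/down projections end up with a dim-1 shard, and layernorm weights stay replicated on every TP rank. The snippet below is a minimal standalone sketch of that shape rule, not code from this patch: the helper name `shard_for_tp_rank` and the tensor sizes are assumptions chosen for illustration, and it uses plain per-rank offset slicing everywhere, whereas the patch pads or truncates the attention projections to the per-rank size instead.

import torch

# Column-parallel weights are split along dim 0 (output features);
# row-parallel weights along dim 1 (input features); everything else
# (e.g. layernorm weights) stays replicated on every TP rank.
_COL_PARALLEL = ("embedding.weight", "final_proj.weight", "q_proj.weight",
                 "k_proj.weight", "v_proj.weight", "gate_proj.weight", "up_proj.weight")
_ROW_PARALLEL = ("out_proj.weight", "down_proj.weight")

def shard_for_tp_rank(name: str, tensor: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Return the slice of `tensor` kept by TP rank `tp_rank` (illustrative helper, not picotron API)."""
    if any(name.endswith(k) for k in _COL_PARALLEL):
        dim = 0
    elif any(name.endswith(k) for k in _ROW_PARALLEL):
        dim = 1
    else:
        return tensor  # replicated parameter
    chunk = tensor.shape[dim] // tp_size
    return tensor.narrow(dim, tp_rank * chunk, chunk).contiguous()

if __name__ == "__main__":
    tp_size = 2
    q_proj = torch.randn(4096, 4096)      # [num_heads * head_dim, hidden_size] (illustrative sizes)
    down_proj = torch.randn(4096, 11008)  # [hidden_size, intermediate_size]
    for tp_rank in range(tp_size):
        q = shard_for_tp_rank("model.layers.0.attention.q_proj.weight", q_proj, tp_rank, tp_size)
        d = shard_for_tp_rank("model.layers.0.mlp.down_proj.weight", down_proj, tp_rank, tp_size)
        print(tp_rank, tuple(q.shape), tuple(d.shape))
    # prints: 0 (2048, 4096) (4096, 5504) then 1 (2048, 4096) (4096, 5504)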