now handle TP + PP meta device
parent bccee5d037
commit a84a9d5942
@@ -10,6 +10,8 @@ import contextlib
+from picotron.utils import assert_no_meta_tensors
 import picotron.process_group_manager as pgm

+from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel

 @contextlib.contextmanager
 def init_model_with_dematerialized_weights(include_buffers: bool = False):
     """
@@ -85,10 +87,11 @@ def init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_
     dist.barrier()
     model.load_state_dict(state_dict, strict=True, assign=True)
     dist.barrier()

+    assert_no_meta_tensors(model)
     # Initialize model parameters
     initialization_manager.init_model_parameters()

     dist.barrier()
     return model

 class InitializationManager:
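For orientation, here is a minimal, self-contained sketch of the meta-device pattern these two hunks build on. The toy nn.Linear and the random state dict are made up for illustration (this is not picotron's code, and assign= needs a recent PyTorch): parameters created under a meta device own no storage, load_state_dict(..., assign=True) swaps real tensors in, and the final check plays the role of assert_no_meta_tensors(model).

import torch
import torch.nn as nn

# Build the module on the meta device: parameters exist but allocate no memory.
with torch.device("meta"):
    model = nn.Linear(4, 4)

# Materialize by assigning real tensors from a (here, randomly generated) state dict.
state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
model.load_state_dict(state_dict, strict=True, assign=True)

# Equivalent in spirit to assert_no_meta_tensors(model).
assert not any(p.is_meta for p in model.parameters())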
@@ -115,43 +118,87 @@ class InitializationManager:

        # Generate base layer names
        layer_names = []
        #TODO: what if there is only tensor parallel that is activated ?
        base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
        if isinstance(self.model, PipelineParallel):
            base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
        else:
            base_names = [f"model.layers.{id}" for id in range(self.model_config.num_hidden_layers)]

        for layer in base_names:
            layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)

        # Add special layers based on pipeline stage
        if pgm.process_group_manager.pp_is_first_stage:
        # Add special layers based on pipeline stage or non-PP case
        if isinstance(self.model, PipelineParallel):
            if pgm.process_group_manager.pp_is_first_stage:
                layer_names.insert(0, "model.embed_tokens.weight")
            elif pgm.process_group_manager.pp_is_last_stage:
                layer_names.extend(["model.norm.weight", "lm_head.weight"])
        else:
            layer_names.insert(0, "model.embed_tokens.weight")
        elif pgm.process_group_manager.pp_is_last_stage:
            layer_names.extend(["model.norm.weight", "lm_head.weight"])

        return layer_names
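To make the naming concrete, a small hypothetical walk-through of the branch above; the component list, layer count, and stage split below are invented for illustration, not taken from picotron:

# Hypothetical example: last of two pipeline stages owning layers 2 and 3 of a 4-layer model.
decoder_components = ["input_layernorm", "self_attn.q_proj", "mlp.up_proj"]  # assumed subset
layer_distribution = [2, 3]                                                  # assumed PP split

layer_names = []
for layer in [f"model.layers.{i}" for i in layer_distribution]:
    layer_names.extend(f"{layer}.{c}.weight" for c in decoder_components)
layer_names.extend(["model.norm.weight", "lm_head.weight"])  # last stage adds norm + lm_head

# layer_names[0]  == "model.layers.2.input_layernorm.weight"
# layer_names[-1] == "lm_head.weight"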

    def adjust_tensor_size(self, tensor, name):
        """Resize tensor based on architecture changes."""
        if 'attention' not in name:
            return tensor

        """Resize tensor based on architecture changes and tensor parallelism."""
        tp_rank = pgm.process_group_manager.tp_rank
        tp_size = pgm.process_group_manager.tp_world_size
        hidden_size = self.model_config.hidden_size
        head_dim = hidden_size // self.model_config.num_attention_heads

        if 'q_proj.weight' in name:
            target_dim = self.model_config.num_attention_heads * head_dim
        elif 'k_proj.weight' in name or 'v_proj.weight' in name:
            target_dim = self.model_config.num_key_value_heads * head_dim
        else:
        # Handle embedding and final projection layers
        if 'embedding.weight' in name or 'final_proj.weight' in name:
            vocab_size = self.model_config.vocab_size
            vocab_per_rank = vocab_size // tp_size
            if tensor.shape[0] != vocab_per_rank:
                start_idx = tp_rank * vocab_per_rank
                end_idx = start_idx + vocab_per_rank
                tensor = tensor[start_idx:end_idx, :]
            return tensor

        # Adjust tensor size if needed
        if tensor.shape[0] != target_dim:
            if target_dim > tensor.shape[0]:
                pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1],
                                         dtype=tensor.dtype, device=tensor.device)
                tensor = torch.cat([tensor, pad_tensor], dim=0)

        # Handle attention layers
        if 'attention' in name:
            head_dim = hidden_size // self.model_config.num_attention_heads

            if 'q_proj.weight' in name:
                total_heads = self.model_config.num_attention_heads
                heads_per_rank = total_heads // tp_size
                target_dim = heads_per_rank * head_dim
            elif 'k_proj.weight' in name or 'v_proj.weight' in name:
                total_heads = self.model_config.num_key_value_heads
                heads_per_rank = total_heads // tp_size
                target_dim = heads_per_rank * head_dim
            elif 'out_proj.weight' in name:
                # For out_proj, we split along the second dimension
                target_dim = tensor.shape[0]  # First dimension stays the same
                if tensor.shape[1] != hidden_size // tp_size:
                    tensor = tensor[:, (hidden_size // tp_size) * tp_rank:(hidden_size // tp_size) * (tp_rank + 1)]
                return tensor
            else:
                tensor = tensor[:target_dim, :]
                return tensor

            if tensor.shape[0] != target_dim:
                if target_dim > tensor.shape[0]:
                    pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1],
                                             dtype=tensor.dtype, device=tensor.device)
                    tensor = torch.cat([tensor, pad_tensor], dim=0)
                else:
                    tensor = tensor[:target_dim, :]

        # Handle MLP layers
        elif 'mlp' in name:
            intermediate_size = self.model_config.intermediate_size
            intermediate_size_per_rank = intermediate_size // tp_size

            if 'up_proj.weight' in name or 'gate_proj.weight' in name:
                if tensor.shape[0] != intermediate_size_per_rank:
                    start_idx = tp_rank * intermediate_size_per_rank
                    end_idx = start_idx + intermediate_size_per_rank
                    tensor = tensor[start_idx:end_idx, :]
            elif 'down_proj.weight' in name:
                if tensor.shape[1] != intermediate_size_per_rank:
                    start_idx = tp_rank * intermediate_size_per_rank
                    end_idx = start_idx + intermediate_size_per_rank
                    tensor = tensor[:, start_idx:end_idx]

        return tensor

    def convert_safetensors_to_hf_name(self, sft_name):
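The slicing above follows the usual tensor-parallel layout: q/k/v and up/gate projections are sharded along the output dimension (rows of the weight), out_proj and down_proj along the input dimension (columns), and the embedding/final projection along the vocabulary. Below is a standalone sketch with made-up shapes and a hypothetical rank, independent of picotron's process groups:

import torch

tp_rank, tp_size = 1, 4                      # hypothetical rank in a 4-way TP group
hidden_size, num_heads = 512, 8
head_dim = hidden_size // num_heads

# q_proj.weight: shard the output dimension (dim 0).
full_q = torch.randn(num_heads * head_dim, hidden_size)
rows = (num_heads // tp_size) * head_dim
q_shard = full_q[tp_rank * rows:(tp_rank + 1) * rows, :]          # -> [128, 512]

# out_proj.weight: shard the input dimension (dim 1).
full_out = torch.randn(hidden_size, hidden_size)
cols = hidden_size // tp_size
out_shard = full_out[:, tp_rank * cols:(tp_rank + 1) * cols]      # -> [512, 128]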
@@ -19,6 +19,7 @@ class RingAttentionFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, sm_scale, is_causal):
        comm = ContextCommunicate("comm")
        #TODO(fmom): add flex attention
        #TODO(fmom): add flash attention
        #TODO(fmom): Find a better way to save these tensors without cloning
        k_og = k.clone()
@@ -160,7 +160,6 @@ class Attention(nn.Module):
        out = self.out_proj(out) # [batch_size, seq_length, hidden_dim]
        return out


class MLP(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
@@ -8,13 +8,9 @@ class ProcessGroupManager:
         self.world_size = dist.get_world_size()
         self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size))

-        self.tp_size = tp_size
-        self.cp_size = cp_size
-        self.pp_size = pp_size
-        self.dp_size = dp_size
-        assert self.world_size == self.tp_size * self.cp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * CP ({self.cp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"
+        assert self.world_size == tp_size * cp_size * pp_size * dp_size, f"World size ({self.world_size}) != TP ({tp_size}) * CP ({cp_size}) * PP ({pp_size}) * DP ({dp_size})"

-        self.grid = torch.arange(self.world_size).view(self.dp_size, self.pp_size, self.cp_size, self.tp_size) # DP * PP * CP * TP grid
+        self.grid = torch.arange(self.world_size).view(dp_size, pp_size, cp_size, tp_size) # DP * PP * CP * TP grid
         # Find the position of the current process in the grid
         self.dp_rank, self.pp_rank, self.cp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()
@@ -36,36 +32,36 @@ class ProcessGroupManager:
         self.cp_dp_group_ids = self.grid[:, self.pp_rank, :, self.tp_rank].flatten().tolist()

         # Tensor parallelism
+        self.tp_world_size = dist.get_world_size(group=self.tp_group)
         self.tp_first_rank = self.tp_group_ids[0]
         self.tp_last_rank = self.tp_group_ids[-1]
-        self.tp_world_size = dist.get_world_size(group=self.tp_group)

         # Context parallelism
+        self.cp_world_size = dist.get_world_size(group=self.cp_group)
         self.cp_first_rank = self.cp_group_ids[0]
         self.cp_last_rank = self.cp_group_ids[-1]
-        self.cp_world_size = dist.get_world_size(group=self.cp_group)
-        self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_size]
-        self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_size]
+        self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_world_size]
+        self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_world_size]

         # Pipeline parallelism
+        self.pp_world_size = dist.get_world_size(group=self.pp_group)
         self.pp_first_rank = self.pp_group_ids[0]
         self.pp_last_rank = self.pp_group_ids[-1]
         self.pp_is_first_stage = self.pp_rank == 0
-        self.pp_is_last_stage = self.pp_rank == self.pp_size - 1
-        self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.cp_rank, self.tp_rank].item())
+        self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1
+        self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.cp_rank, self.tp_rank].item())
         self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.cp_rank, self.tp_rank].item())
-        self.pp_world_size = dist.get_world_size(group=self.pp_group)

         # Data parallelism
+        self.dp_world_size = dist.get_world_size(group=self.dp_group)
         self.dp_first_rank = self.dp_group_ids[0]
         self.dp_last_rank = self.dp_group_ids[-1]
-        self.dp_world_size = dist.get_world_size(group=self.dp_group)

         # Context + Data parallelism
         self.cp_dp_world_size = dist.get_world_size(group=self.cp_dp_group)

     def __str__(self):
-        return f"TP({self.tp_size})-CP({self.cp_size})-PP({self.pp_size})-DP({self.dp_size})-Rank({self.global_rank})"
+        return f"TP({self.tp_world_size})-CP({self.cp_world_size})-PP({self.pp_world_size})-DP({self.dp_world_size})-Rank({self.global_rank})"

 def setup_process_group_manager(tp_size, cp_size, pp_size, dp_size):
     global process_group_manager
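A self-contained illustration of the grid bookkeeping above, using example sizes and no actual process groups; the slice at the end is illustrative and mirrors how the manager derives per-axis group ids (e.g. cp_dp_group_ids in the hunk):

import torch

dp_size, pp_size, cp_size, tp_size = 2, 2, 1, 2   # example sizes, world_size = 8
world_size = dp_size * pp_size * cp_size * tp_size
grid = torch.arange(world_size).view(dp_size, pp_size, cp_size, tp_size)  # DP * PP * CP * TP grid

global_rank = 5
dp_rank, pp_rank, cp_rank, tp_rank = (grid == global_rank).nonzero().flatten().tolist()
# -> dp_rank=1, pp_rank=0, cp_rank=0, tp_rank=1

tp_group_ids = grid[dp_rank, pp_rank, cp_rank, :].tolist()   # ranks sharing this rank's TP axis
# -> [4, 5]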
train.py
@@ -162,7 +162,7 @@ if __name__ == "__main__":

     model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])

-    # TODO: load existing checkpoint here to continue pre-training
+    #TODO: load existing checkpoint here to continue pre-training

     if pgm.process_group_manager.cp_world_size > 1:
         model = apply_context_parallel(model)