wip: load big model with meta device

ferdinand.mom 2024-11-29 16:38:42 +00:00
parent 099621fd94
commit 5045be87e0
5 changed files with 379 additions and 13 deletions

extract_metrics.py Normal file (183 additions)

@@ -0,0 +1,183 @@
import re
import csv
import glob
import os
import argparse
import numpy as np
def parse_folder_name(folder_name):
dp = re.search(r'dp(\d+)', folder_name)
tp = re.search(r'tp(\d+)', folder_name)
pp = re.search(r'pp(\d+)', folder_name)
mbs = re.search(r'mbs(\d+)', folder_name)
ga = re.search(r'ga(\d+)', folder_name)
sl = re.search(r'sl(\d+)', folder_name)
return {
'dp': int(dp.group(1)) if dp else None,
'tp': int(tp.group(1)) if tp else None,
'pp': int(pp.group(1)) if pp else None,
'micro_batch_size': int(mbs.group(1)) if mbs else None,
'grad_acc': int(ga.group(1)) if ga else None,
'seq_len': int(sl.group(1)) if sl else None
}
def from_readable_format(formatted_str):
if not isinstance(formatted_str, str):
return formatted_str
# Remove any whitespace and convert to upper case for consistency
formatted_str = formatted_str.strip().upper()
# If it's just a number without suffix, return float
try:
return float(formatted_str)
except ValueError:
pass
# Define multipliers
multipliers = {
'T': 1e12,
'B': 1e9,
'M': 1e6,
'K': 1e3
}
# Extract number and suffix
number = float(formatted_str[:-1])
suffix = formatted_str[-1]
if suffix in multipliers:
return number * multipliers[suffix]
else:
raise ValueError(f"Unknown suffix: {suffix}")
def parse_log_line(line):
tokens_s_gpu_match = re.search(r'Tokens/s/GPU: ([\d.]+[KMBT]?)', line)
if tokens_s_gpu_match:
value = tokens_s_gpu_match.group(1)
return from_readable_format(value)
return None
def process_file(filepath):
tokens_s_gpu_values = []
with open(filepath, 'r') as f:
for line in f:
if '[default0]:[rank 0]' in line:
tokens_s_gpu = parse_log_line(line)
if tokens_s_gpu is not None:
tokens_s_gpu_values.append(tokens_s_gpu)
return int(round(np.mean(tokens_s_gpu_values))) if tokens_s_gpu_values else None
def write_csv(data, output_filepath):
if not data:
return
fieldnames = ['run_name', 'status', 'dp', 'tp', 'pp', 'micro_batch_size', 'grad_acc', 'seq_len', 'avg_tokens_s_gpu']
with open(output_filepath, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerow(data)
def read_status(status_file):
try:
with open(status_file, 'r') as f:
return f.read().strip()
except OSError:
return None
def create_subdirectory_metrics(input_folder):
"""Create metrics.csv files in each subdirectory"""
pattern = os.path.join(input_folder, '**/*.out')
out_files = glob.glob(pattern, recursive=True)
print(f"Found {len(out_files)} .out files")
processed_dirs = []
for file_path in out_files:
dir_path = os.path.dirname(file_path)
dir_name = os.path.basename(dir_path)
output_csv = os.path.join(dir_path, 'metrics.csv')
params = parse_folder_name(dir_name)
avg_tokens_s_gpu = process_file(file_path)
status = read_status(os.path.join(dir_path, 'status.txt'))
params['run_name'] = dir_name
if status is not None:
    params['status'] = status
if avg_tokens_s_gpu is not None:
    params['avg_tokens_s_gpu'] = avg_tokens_s_gpu
# Write metrics.csv once, after all available fields have been collected
write_csv(params, output_csv)
processed_dirs.append(dir_path)
print(f"Processed {file_path} -> Created metrics.csv")
return processed_dirs
def aggregate_metrics(input_folder):
"""Create global_metrics.csv from all subdirectory metrics"""
top_level_dir = glob.glob(input_folder + '/*')
for top_dir_path in top_level_dir:
subdirs = glob.glob(top_dir_path + '/*')
aggregated_data = []
for subdir_path in subdirs:
metrics_file = os.path.join(subdir_path, 'metrics.csv')
status_file = os.path.join(subdir_path, 'status.txt')
folder_name = os.path.basename(subdir_path)
data = {
'run_name': folder_name,
'status': read_status(status_file),
**parse_folder_name(folder_name) # Unpack the parsed parameters
}
# If metrics.csv exists, read the avg_tokens_s_gpu from it
if os.path.exists(metrics_file):
try:
with open(metrics_file, 'r') as f:
reader = csv.DictReader(f)
metrics_data = next(reader)
data['avg_tokens_s_gpu'] = int(metrics_data['avg_tokens_s_gpu'])
except (StopIteration, KeyError, ValueError):
data['avg_tokens_s_gpu'] = -1
else:
data['avg_tokens_s_gpu'] = -1
aggregated_data.append(data)
# Write global metrics file
output_file = os.path.join(top_dir_path, 'global_metrics.csv')
fieldnames = ['run_name', 'status', 'dp', 'tp', 'pp', 'micro_batch_size',
'grad_acc', 'seq_len', 'avg_tokens_s_gpu']
with open(output_file, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(aggregated_data)
print(f"Created global_metrics.csv with {len(aggregated_data)} entries")
def main():
parser = argparse.ArgumentParser(description='Process log files and create metrics CSVs')
parser.add_argument('input_folder', help='Path to the top-level folder containing experiment subfolders')
args = parser.parse_args()
# Step 1: Create metrics.csv in each subdirectory
print("Creating individual metrics.csv files...")
create_subdirectory_metrics(args.input_folder)
# Step 2: Create global_metrics.csv
print("\nAggregating metrics into global_metrics.csv...")
aggregate_metrics(args.input_folder)
if __name__ == "__main__":
main()
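
For reference, a usage sketch of the script above (the directory name below is hypothetical; the script only requires that the dp/tp/pp/mbs/ga/sl tokens appear somewhere in each run directory's name):

# Build one metrics.csv per run directory, then a global_metrics.csv per top-level experiment folder:
#   python extract_metrics.py /path/to/benchmark_results
#
# What the parsers produce on sample inputs:
#   parse_folder_name("bench_dp2_tp4_pp1_mbs1_ga8_sl4096")
#   -> {'dp': 2, 'tp': 4, 'pp': 1, 'micro_batch_size': 1, 'grad_acc': 8, 'seq_len': 4096}
#   from_readable_format("34.5K") -> 34500.0
#   parse_log_line("[default0]:[rank 0] ... Tokens/s/GPU: 34.50K") -> 34500.0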

picotron/checkpoint.py Normal file (164 additions)

@@ -0,0 +1,164 @@
import os
import re
import torch
import torch.nn as nn
import torch.distributed as dist
from safetensors import safe_open
import contextlib
from picotron.utils import assert_no_meta_tensors
import picotron.process_group_manager as pgm
@contextlib.contextmanager
def init_model_with_dematerialized_weights(include_buffers: bool = False):
"""
From Accelerate library: https://github.com/huggingface/accelerate/blob/v0.11.0/src/accelerate/big_modeling.py#L254
Context manager that initializes models with empty weights (no memory allocation).
Args:
include_buffers (bool): Whether to also skip buffer initialization.
"""
old_register_parameter = nn.Module.register_parameter
if include_buffers:
old_register_buffer = nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta")), **kwargs)
def register_empty_buffer(module, name, buffer):
old_register_buffer(module, name, buffer)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(torch.device("meta"))
try:
nn.Module.register_parameter = register_empty_parameter
if include_buffers:
nn.Module.register_buffer = register_empty_buffer
yield
finally:
nn.Module.register_parameter = old_register_parameter
if include_buffers:
nn.Module.register_buffer = old_register_buffer
def initialize_model_with_materialized_weights(model, model_config, checkpoint_path, initialize_weight_tensor_func = None):
"""Initialize model with correct tensor shapes but random weights"""
initialization_manager = InitializationManager(model, model_config)
# convert layer distribution ids to layer_name (using the same naming convention as in safetensors)
model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format()
print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}")
safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
safetensors_names = f.keys()
if len(safetensors_names) > len(model_layer_name_sft_format):
print(f"Warning: Checkpoint has {len(safetensors_names)} layers but model only has {len(model_layer_name_sft_format)} layers.")
# Create state dict with random tensors
state_dict = {}
for sft_name in model_layer_name_sft_format:
# if is_tensor_belongs_to_current_pp_rank(sft_name, model_layer_name_sft_format):
hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
tensor = f.get_tensor(sft_name)
tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
#TODO: initialize_weight_tensor_func
#TODO: is layernorm init the same way as q k v ?
state_dict[hf_name] = torch.randn_like(tensor)
#TODO: Handle Tensor Parallel splitting if needed
dist.barrier()
model.load_state_dict(state_dict, strict=True, assign=True)
dist.barrier()
assert_no_meta_tensors(model)
return model
class InitializationManager:
def __init__(self, model, model_config):
self.model = model
self.model_config = model_config
def get_layer_names_in_sft_format(self):
"""Get layer names in safetensors format based on model's layer distribution."""
decoder_components = [
"input_layernorm",
"mlp.down_proj",
"mlp.gate_proj",
"mlp.up_proj",
"post_attention_layernorm",
"self_attn.k_proj",
"self_attn.o_proj",
"self_attn.q_proj",
"self_attn.v_proj",
]
# Generate base layer names
layer_names = []
base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
for layer in base_names:
layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
# Add special layers based on pipeline stage
if pgm.process_group_manager.pp_is_first_stage:
layer_names.insert(0, "model.embed_tokens.weight")
elif pgm.process_group_manager.pp_is_last_stage:
layer_names.extend(["model.norm.weight", "lm_head.weight"])
return layer_names
def adjust_tensor_size(self, tensor, name):
"""Resize tensor based on architecture changes."""
if 'attention' not in name:
return tensor
hidden_size = self.model_config.hidden_size
head_dim = hidden_size // self.model_config.num_attention_heads
if 'q_proj.weight' in name:
target_dim = self.model_config.num_attention_heads * head_dim
elif 'k_proj.weight' in name or 'v_proj.weight' in name:
target_dim = self.model_config.num_key_value_heads * head_dim
else:
return tensor
# Adjust tensor size if needed
if tensor.shape[0] != target_dim:
if target_dim > tensor.shape[0]:
pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1],
dtype=tensor.dtype, device=tensor.device)
tensor = torch.cat([tensor, pad_tensor], dim=0)
else:
tensor = tensor[:target_dim, :]
return tensor
def convert_safetensors_to_hf_name(self, sft_name):
"""Convert safetensors naming convention to HuggingFace naming convention."""
name_mapping = {
"model.": "",
"layers.": "decoder_layers.",
"embed_tokens": "embedding",
"self_attn.": "attention.",
"o_proj": "out_proj",
"lm_head": "final_proj",
"input_layernorm": "input_layernorm",
"post_attention_layernorm": "post_attention_layernorm",
r'^norm': 'final_norm'
}
result = sft_name
for pattern, replacement in name_mapping.items():
result = re.sub(pattern, replacement, result)
return result
#TODO: Implement and Move save/load checkpoint here
# class CheckpointManager:
# pass
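
The context manager above is adapted from Accelerate's init_empty_weights; here is a self-contained sketch of the mechanism it builds on, using a toy nn.Linear instead of the Llama model (illustrative only, not part of this commit; assign= requires a recent PyTorch, 2.1+):

import torch
import torch.nn as nn
from picotron.checkpoint import init_model_with_dematerialized_weights

# Constructing the module inside the context manager allocates no weight memory:
# every registered parameter is immediately moved to the meta device.
with init_model_with_dematerialized_weights():
    layer = nn.Linear(4096, 4096)
assert layer.weight.is_meta

# Materialize real weights later. assign=True swaps the meta parameters for the
# provided tensors instead of copying into them (copying would fail on meta tensors).
state_dict = {"weight": torch.randn(4096, 4096), "bias": torch.zeros(4096)}
layer.load_state_dict(state_dict, strict=True, assign=True)
assert not layer.weight.is_meta

This is what lets the tensor/pipeline-parallel wrappers slice the model while it is still on the meta device: only the tensors a rank actually keeps ever get materialized.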

picotron/pipeline_parallel/pipeline_parallel.py

@@ -8,9 +8,9 @@ from picotron.pipeline_parallel.pp_communications import pipeline_communicate, b
class PipelineParallel(nn.Module):
def __init__(self, model, config):
super().__init__()
layer_distribution = self.distribute_layers(config.num_hidden_layers)
self.layer_distribution = self.distribute_layers(config.num_hidden_layers)
self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in layer_distribution})
self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in self.layer_distribution})
self.final_norm = model.final_norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
self.final_proj = model.final_proj if pgm.process_group_manager.pp_is_last_stage else nn.Identity()

picotron/utils.py

@@ -4,6 +4,8 @@ import random
import numpy as np
import builtins
import fcntl
import json
import torch.nn as nn
import picotron.process_group_manager as pgm
def print(*args, is_print_rank=True, **kwargs):
@@ -32,6 +34,18 @@ def to_readable_format(num, precision=2):
return f"{num / 1e3:.{precision}f}K"
else:
return f"{num:.{precision}f}"
def assert_no_meta_tensors(model):
meta_tensors = []
for name, param in model.named_parameters():
if param.device == torch.device("meta"):
meta_tensors.append(f"Parameter '{name}' with shape {param.shape}")
for name, buffer in model.named_buffers():
if buffer.device == torch.device("meta"):
meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")
assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)
def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
"""Save the model/optimizer states/steps to a checkpoint file."""
@@ -67,4 +81,4 @@ def load_checkpoint(model, optimizer, out_dir):
raw_model.load_state_dict(checkpoint['model'])
# Load optimizer state
optimizer.load_state_dict(checkpoint['optimizer'])
return checkpoint['trained_steps'], checkpoint['trained_tokens']

train.py

@@ -23,12 +23,15 @@ from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel import DataParallelBucket
from picotron.model import Llama
import wandb
import lovely_tensors as lt; lt.monkey_patch()
def all_reduce_loss_across_dp_cp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
@@ -66,7 +69,7 @@ def train_step(model, data_loader, device):
return acc_loss
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="", help="Path to config file")
args = parser.parse_args()
@@ -173,20 +176,22 @@ if __name__ == "__main__":
model_config.max_position_embeddings = SEQ_LEN
start_time = time.time()
model = Llama(config=model_config)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
dist.barrier()
start_time = time.time()
with init_model_with_dematerialized_weights():
model = Llama(config=model_config)
if pgm.process_group_manager.tp_world_size > 1:
#TODO: remove the initialize_weight_tensor and do it at initialize_model_with_materialized_weights() level
model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/TinyLlama-1.1B-Chat-v0.1", initialize_weight_tensor_func=initialize_weight_tensor)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()
if pgm.process_group_manager.cp_world_size > 1:
model = apply_context_parallel(model)
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
model.to(dtype).to(device)
@@ -266,4 +271,4 @@
if is_wandb_rank and USE_WANDB:
wandb.finish()
dist.destroy_process_group()
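
Because the train.py hunk above interleaves removed and added lines, here is the new initialization order as reconstructed from it (a best-effort reading of this WIP diff, not additional code in the commit):

# Build the model with no weight allocation, then wrap it for TP/PP while still on the meta device.
with init_model_with_dematerialized_weights():
    model = Llama(config=model_config)                    # weights live on the meta device
    if pgm.process_group_manager.tp_world_size > 1:
        model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, model_config)     # keeps only this pp rank's layers

# Materialize this rank's tensors with the shapes found in the safetensors checkpoint
# (random weights for now; see the TODOs in picotron/checkpoint.py). The path is
# hardcoded to a local TinyLlama checkpoint in this WIP commit.
model = initialize_model_with_materialized_weights(
    model, model_config,
    checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/TinyLlama-1.1B-Chat-v0.1",
    initialize_weight_tensor_func=initialize_weight_tensor,
)

if pgm.process_group_manager.cp_world_size > 1:
    model = apply_context_parallel(model)
model.to(dtype).to(device)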