From 5045be87e0adf4dd64ac635901cbcf1405a78db4 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 29 Nov 2024 16:38:42 +0000 Subject: [PATCH] wip: load big model with meta device --- extract_metrics.py | 183 ++++++++++++++++++ picotron/checkpoint.py | 164 ++++++++++++++++ .../pipeline_parallel/pipeline_parallel.py | 4 +- picotron/utils.py | 16 +- train.py | 25 ++- 5 files changed, 379 insertions(+), 13 deletions(-) create mode 100644 extract_metrics.py create mode 100644 picotron/checkpoint.py diff --git a/extract_metrics.py b/extract_metrics.py new file mode 100644 index 0000000..b26b08e --- /dev/null +++ b/extract_metrics.py @@ -0,0 +1,183 @@ +import re +import csv +import glob +import os +import argparse +import numpy as np + +def parse_folder_name(folder_name): + dp = re.search(r'dp(\d+)', folder_name) + tp = re.search(r'tp(\d+)', folder_name) + pp = re.search(r'pp(\d+)', folder_name) + mbs = re.search(r'mbs(\d+)', folder_name) + ga = re.search(r'ga(\d+)', folder_name) + sl = re.search(r'sl(\d+)', folder_name) + + return { + 'dp': int(dp.group(1)) if dp else None, + 'tp': int(tp.group(1)) if tp else None, + 'pp': int(pp.group(1)) if pp else None, + 'micro_batch_size': int(mbs.group(1)) if mbs else None, + 'grad_acc': int(ga.group(1)) if ga else None, + 'seq_len': int(sl.group(1)) if sl else None + } + +def from_readable_format(formatted_str): + if not isinstance(formatted_str, str): + return formatted_str + + # Remove any whitespace and convert to upper case for consistency + formatted_str = formatted_str.strip().upper() + + # If it's just a number without suffix, return float + try: + return float(formatted_str) + except ValueError: + pass + + # Define multipliers + multipliers = { + 'T': 1e12, + 'B': 1e9, + 'M': 1e6, + 'K': 1e3 + } + + # Extract number and suffix + number = float(formatted_str[:-1]) + suffix = formatted_str[-1] + + if suffix in multipliers: + return number * multipliers[suffix] + else: + raise ValueError(f"Unknown suffix: {suffix}") + +def parse_log_line(line): + tokens_s_gpu_match = re.search(r'Tokens/s/GPU: ([\d.]+[KMBT]?)', line) + if tokens_s_gpu_match: + value = tokens_s_gpu_match.group(1) + return from_readable_format(value) + return None + +def process_file(filepath): + tokens_s_gpu_values = [] + with open(filepath, 'r') as f: + for line in f: + if '[default0]:[rank 0]' in line: + tokens_s_gpu = parse_log_line(line) + if tokens_s_gpu is not None: + tokens_s_gpu_values.append(tokens_s_gpu) + + return int(round(np.mean(tokens_s_gpu_values))) if tokens_s_gpu_values else None + +def write_csv(data, output_filepath): + if not data: + return + + fieldnames = ['run_name', 'status', 'dp', 'tp', 'pp', 'micro_batch_size', 'grad_acc', 'seq_len', 'avg_tokens_s_gpu'] + with open(output_filepath, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerow(data) + +def read_status(status_file): + try: + with open(status_file, 'r') as f: + return f.read().strip() + except: + return None + +def create_subdirectory_metrics(input_folder): + """Create metrics.csv files in each subdirectory""" + pattern = os.path.join(input_folder, '**/*.out') + out_files = glob.glob(pattern, recursive=True) + + print(f"Found {len(out_files)} .out files") + + processed_dirs = [] + for file_path in out_files: + dir_path = os.path.dirname(file_path) + dir_name = os.path.basename(dir_path) + output_csv = os.path.join(dir_path, 'metrics.csv') + + params = parse_folder_name(dir_name) + avg_tokens_s_gpu = process_file(file_path) + status = 
read_status(os.path.join(dir_path, 'status.txt')) + + params['run_name'] = dir_name + write_csv(params, output_csv) + + if status is not None: + params['status'] = status + write_csv(params, output_csv) + + if avg_tokens_s_gpu is not None: + params['avg_tokens_s_gpu'] = avg_tokens_s_gpu + write_csv(params, output_csv) + processed_dirs.append(dir_path) + print(f"Processed {file_path} -> Created metrics.csv") + + return processed_dirs + +def aggregate_metrics(input_folder): + """Create global_metrics.csv from all subdirectory metrics""" + top_level_dir = glob.glob(input_folder + '/*') + + for top_dir_path in top_level_dir: + subdirs = glob.glob(top_dir_path + '/*') + + aggregated_data = [] + + for subdir_path in subdirs: + metrics_file = os.path.join(subdir_path, 'metrics.csv') + status_file = os.path.join(subdir_path, 'status.txt') + + folder_name = os.path.basename(subdir_path) + + data = { + 'run_name': folder_name, + 'status': read_status(status_file), + **parse_folder_name(folder_name) # Unpack the parsed parameters + } + + # If metrics.csv exists, read the avg_tokens_s_gpu from it + if os.path.exists(metrics_file): + try: + with open(metrics_file, 'r') as f: + reader = csv.DictReader(f) + metrics_data = next(reader) + data['avg_tokens_s_gpu'] = int(metrics_data['avg_tokens_s_gpu']) + except: + data['avg_tokens_s_gpu'] = -1 + else: + data['avg_tokens_s_gpu'] = -1 + + aggregated_data.append(data) + + # Write global metrics file + output_file = os.path.join(top_dir_path, 'global_metrics.csv') + fieldnames = ['run_name', 'status', 'dp', 'tp', 'pp', 'micro_batch_size', + 'grad_acc', 'seq_len', 'avg_tokens_s_gpu'] + + with open(output_file, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(aggregated_data) + + print(f"Created global_metrics.csv with {len(aggregated_data)} entries") + +def main(): + parser = argparse.ArgumentParser(description='Process log files and create metrics CSVs') + parser.add_argument('input_folder', help='Path to the top-level folder containing experiment subfolders') + args = parser.parse_args() + + # Step 1: Create metrics.csv in each subdirectory + print("Creating individual metrics.csv files...") + create_subdirectory_metrics(args.input_folder) + + # Step 2: Create global_metrics.csv + print("\nAggregating metrics into global_metrics.csv...") + aggregate_metrics(args.input_folder) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py new file mode 100644 index 0000000..13ec589 --- /dev/null +++ b/picotron/checkpoint.py @@ -0,0 +1,164 @@ +import os +import re +import torch +import torch.nn as nn +import torch.distributed as dist +from safetensors import safe_open +import contextlib + +from picotron.utils import assert_no_meta_tensors +import picotron.process_group_manager as pgm + +@contextlib.contextmanager +def init_model_with_dematerialized_weights(include_buffers: bool = False): + """ + From Accelerate library: https://github.com/huggingface/accelerate/blob/v0.11.0/src/accelerate/big_modeling.py#L254 + Context manager that initializes models with empty weights (no memory allocation). + + Args: + include_buffers (bool): Whether to also skip buffer initialization. 
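+
+    Note: parameters (and, if include_buffers=True, buffers) created inside this
+    context live on the meta device and hold no data. They must be materialized
+    afterwards, e.g. via load_state_dict(state_dict, assign=True) as done in
+    initialize_model_with_materialized_weights below.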
+ """ + old_register_parameter = nn.Module.register_parameter + if include_buffers: + old_register_buffer = nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta")), **kwargs) + + def register_empty_buffer(module, name, buffer): + old_register_buffer(module, name, buffer) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(torch.device("meta")) + + try: + nn.Module.register_parameter = register_empty_parameter + if include_buffers: + nn.Module.register_buffer = register_empty_buffer + yield + finally: + nn.Module.register_parameter = old_register_parameter + if include_buffers: + nn.Module.register_buffer = old_register_buffer + + +def initialize_model_with_materialized_weights(model, model_config, checkpoint_path, initialize_weight_tensor_func = None): + """Initialize model with correct tensor shapes but random weights""" + + initialization_manager = InitializationManager(model, model_config) + + # convert layer distribution ids to layer_name (using the same naming convention as in safetensors) + model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format() + print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}") + + safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors") + with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f: + safetensors_names = f.keys() + + if len(safetensors_names) > len(model_layer_name_sft_format): + print(f"Warning: Checkpoint has {len(safetensors_names)} layers but model only has {len(model_layer_name_sft_format)} layers.") + + # Create state dict with random tensors + state_dict = {} + for sft_name in model_layer_name_sft_format: + # if is_tensor_belongs_to_current_pp_rank(sft_name, model_layer_name_sft_format): + hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name) + tensor = f.get_tensor(sft_name) + tensor = initialization_manager.adjust_tensor_size(tensor, hf_name) + + #TODO: initialize_weight_tensor_func + #TODO: is layernorm init the same way as q k v ? 
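+            # NOTE (WIP): the checkpoint tensor is only used for its (possibly
+            # adjusted) shape and dtype; the values written into the state dict
+            # are re-drawn at random, so no real weights are loaded yet.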
+ state_dict[hf_name] = torch.randn_like(tensor) + + #TODO: Handle Tensor Parallel splitting if needed + + dist.barrier() + model.load_state_dict(state_dict, strict=True, assign=True) + dist.barrier() + assert_no_meta_tensors(model) + return model + +class InitializationManager: + def __init__(self, model, model_config): + self.model = model + self.model_config = model_config + + def get_layer_names_in_sft_format(self): + """Get layer names in safetensors format based on model's layer distribution.""" + decoder_components = [ + "input_layernorm", + "mlp.down_proj", + "mlp.gate_proj", + "mlp.up_proj", + "post_attention_layernorm", + "self_attn.k_proj", + "self_attn.o_proj", + "self_attn.q_proj", + "self_attn.v_proj", + ] + + # Generate base layer names + layer_names = [] + base_names = [f"model.layers.{id}" for id in self.model.layer_distribution] + for layer in base_names: + layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components) + + # Add special layers based on pipeline stage + if pgm.process_group_manager.pp_is_first_stage: + layer_names.insert(0, "model.embed_tokens.weight") + elif pgm.process_group_manager.pp_is_last_stage: + layer_names.extend(["model.norm.weight", "lm_head.weight"]) + + return layer_names + + def adjust_tensor_size(self, tensor, name): + """Resize tensor based on architecture changes.""" + if 'attention' not in name: + return tensor + + hidden_size = self.model_config.hidden_size + head_dim = hidden_size // self.model_config.num_attention_heads + + if 'q_proj.weight' in name: + target_dim = self.model_config.num_attention_heads * head_dim + elif 'k_proj.weight' in name or 'v_proj.weight' in name: + target_dim = self.model_config.num_key_value_heads * head_dim + else: + return tensor + + # Adjust tensor size if needed + if tensor.shape[0] != target_dim: + if target_dim > tensor.shape[0]: + pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1], + dtype=tensor.dtype, device=tensor.device) + tensor = torch.cat([tensor, pad_tensor], dim=0) + else: + tensor = tensor[:target_dim, :] + + return tensor + + def convert_safetensors_to_hf_name(self, sft_name): + """Convert safetensors naming convention to HuggingFace naming convention.""" + name_mapping = { + "model.": "", + "layers.": "decoder_layers.", + "embed_tokens": "embedding", + "self_attn.": "attention.", + "o_proj": "out_proj", + "lm_head": "final_proj", + "input_layernorm": "input_layernorm", + "post_attention_layernorm": "post_attention_layernorm", + r'^norm': 'final_norm' + } + + result = sft_name + for pattern, replacement in name_mapping.items(): + result = re.sub(pattern, replacement, result) + return result + +#TODO: Implement and Move save/load checkpoint here +# class CheckpointManager: +# pass diff --git a/picotron/pipeline_parallel/pipeline_parallel.py b/picotron/pipeline_parallel/pipeline_parallel.py index 39574f2..d1c2e2c 100644 --- a/picotron/pipeline_parallel/pipeline_parallel.py +++ b/picotron/pipeline_parallel/pipeline_parallel.py @@ -8,9 +8,9 @@ from picotron.pipeline_parallel.pp_communications import pipeline_communicate, b class PipelineParallel(nn.Module): def __init__(self, model, config): super().__init__() - layer_distribution = self.distribute_layers(config.num_hidden_layers) + self.layer_distribution = self.distribute_layers(config.num_hidden_layers) self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity() - self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in layer_distribution}) 
+ self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in self.layer_distribution}) self.final_norm = model.final_norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity() self.final_proj = model.final_proj if pgm.process_group_manager.pp_is_last_stage else nn.Identity() diff --git a/picotron/utils.py b/picotron/utils.py index 5dd100f..30bde4f 100644 --- a/picotron/utils.py +++ b/picotron/utils.py @@ -4,6 +4,8 @@ import random import numpy as np import builtins import fcntl +import json +import torch.nn as nn import picotron.process_group_manager as pgm def print(*args, is_print_rank=True, **kwargs): @@ -32,6 +34,18 @@ def to_readable_format(num, precision=2): return f"{num / 1e3:.{precision}f}K" else: return f"{num:.{precision}f}" + +def assert_no_meta_tensors(model): + meta_tensors = [] + for name, param in model.named_parameters(): + if param.device == torch.device("meta"): + meta_tensors.append(f"Parameter '{name}' with shape {param.shape}") + + for name, buffer in model.named_buffers(): + if buffer.device == torch.device("meta"): + meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}") + + assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors) def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir): """Save the model/optimizer states/steps to a checkpoint file.""" @@ -67,4 +81,4 @@ def load_checkpoint(model, optimizer, out_dir): raw_model.load_state_dict(checkpoint['model']) # Load optimizer state optimizer.load_state_dict(checkpoint['optimizer']) - return checkpoint['trained_steps'], checkpoint['trained_tokens'] \ No newline at end of file + return checkpoint['trained_steps'], checkpoint['trained_tokens'] diff --git a/train.py b/train.py index 10443f7..16bab3b 100644 --- a/train.py +++ b/train.py @@ -23,12 +23,15 @@ from picotron.context_parallel.context_parallel import apply_context_parallel from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor import picotron.process_group_manager as pgm from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint +from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights from picotron.data import MicroBatchDataLoader from picotron.process_group_manager import setup_process_group_manager from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel from picotron.data_parallel.data_parallel import DataParallelBucket from picotron.model import Llama import wandb +import lovely_tensors as lt; lt.monkey_patch() + def all_reduce_loss_across_dp_cp_ranks(loss, device): reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) @@ -66,7 +69,7 @@ def train_step(model, data_loader, device): return acc_loss -if __name__ == "__main__": +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="", help="Path to config file") args = parser.parse_args() @@ -173,20 +176,22 @@ if __name__ == "__main__": model_config.max_position_embeddings = SEQ_LEN start_time = time.time() - model = Llama(config=model_config) - print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank) - dist.barrier() - - start_time = time.time() + with init_model_with_dematerialized_weights(): + model = Llama(config=model_config) if 
pgm.process_group_manager.tp_world_size > 1: + #TODO: remove the initialize_weight_tensor and do it at initialize_model_with_materialized_weights() level model = apply_tensor_parallel(model, init_method=initialize_weight_tensor) + + if pgm.process_group_manager.pp_world_size > 1: + model = PipelineParallel(model, model_config) + + model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/TinyLlama-1.1B-Chat-v0.1", initialize_weight_tensor_func=initialize_weight_tensor) + print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank) + start_time = time.time() if pgm.process_group_manager.cp_world_size > 1: model = apply_context_parallel(model) - - if pgm.process_group_manager.pp_world_size > 1: - model = PipelineParallel(model, model_config) model.to(dtype).to(device) @@ -266,4 +271,4 @@ if __name__ == "__main__": if is_wandb_rank and USE_WANDB: wandb.finish() - dist.destroy_process_group() + dist.destroy_process_group() \ No newline at end of file
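
For context, below is a minimal sketch (not part of the patch) of the meta-device pattern that init_model_with_dematerialized_weights and initialize_model_with_materialized_weights implement. TinyModel is a hypothetical stand-in for Llama, and the snippet assumes picotron with this patch installed plus a PyTorch recent enough for load_state_dict(..., assign=True):

    import torch
    import torch.nn as nn

    from picotron.checkpoint import init_model_with_dematerialized_weights

    # Hypothetical toy module, stand-in for the Llama model in the patch above.
    class TinyModel(nn.Module):
        def __init__(self, hidden=16):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden, bias=False)

    # 1) Build the module structure without allocating storage: parameters
    #    registered inside the context end up on the meta device.
    with init_model_with_dematerialized_weights():
        model = TinyModel()
    assert model.proj.weight.is_meta

    # 2) Materialize: with assign=True, load_state_dict swaps the meta
    #    parameters for the supplied tensors instead of copying into them,
    #    mirroring initialize_model_with_materialized_weights in the patch.
    state_dict = {"proj.weight": torch.randn(16, 16)}
    model.load_state_dict(state_dict, strict=True, assign=True)
    assert not model.proj.weight.is_meta

assign=True is the key detail: the default in-place copy cannot populate parameters that live on the meta device, whereas assignment replaces them with real tensors.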