wip: load big model with meta device
parent 099621fd94
commit 5045be87e0
183 extract_metrics.py Normal file
@@ -0,0 +1,183 @@
import re
import csv
import glob
import os
import argparse
import numpy as np


def parse_folder_name(folder_name):
    """Extract run parameters (dp/tp/pp/mbs/ga/sl) encoded in a run folder's name."""
    dp = re.search(r'dp(\d+)', folder_name)
    tp = re.search(r'tp(\d+)', folder_name)
    pp = re.search(r'pp(\d+)', folder_name)
    mbs = re.search(r'mbs(\d+)', folder_name)
    ga = re.search(r'ga(\d+)', folder_name)
    sl = re.search(r'sl(\d+)', folder_name)

    return {
        'dp': int(dp.group(1)) if dp else None,
        'tp': int(tp.group(1)) if tp else None,
        'pp': int(pp.group(1)) if pp else None,
        'micro_batch_size': int(mbs.group(1)) if mbs else None,
        'grad_acc': int(ga.group(1)) if ga else None,
        'seq_len': int(sl.group(1)) if sl else None
    }

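# Illustrative example, assuming run folders follow the naming scheme matched by the
# regexes above (the exact folder name here is hypothetical):
#   parse_folder_name("dp2_tp4_pp2_mbs1_ga8_sl4096")
#   -> {'dp': 2, 'tp': 4, 'pp': 2, 'micro_batch_size': 1, 'grad_acc': 8, 'seq_len': 4096}
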
def from_readable_format(formatted_str):
    """Convert a human-readable quantity like '12.5K' back to a float."""
    if not isinstance(formatted_str, str):
        return formatted_str

    # Remove any whitespace and convert to upper case for consistency
    formatted_str = formatted_str.strip().upper()

    # If it's just a number without a suffix, return it as a float
    try:
        return float(formatted_str)
    except ValueError:
        pass

    # Define multipliers
    multipliers = {
        'T': 1e12,
        'B': 1e9,
        'M': 1e6,
        'K': 1e3
    }

    # Extract number and suffix
    number = float(formatted_str[:-1])
    suffix = formatted_str[-1]

    if suffix in multipliers:
        return number * multipliers[suffix]
    else:
        raise ValueError(f"Unknown suffix: {suffix}")

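# Worked examples (illustrative): from_readable_format("12.5K") -> 12500.0,
# from_readable_format("1.2M") -> 1200000.0, from_readable_format("42") -> 42.0.
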
def parse_log_line(line):
    """Return the Tokens/s/GPU value reported on a log line, or None if absent."""
    tokens_s_gpu_match = re.search(r'Tokens/s/GPU: ([\d.]+[KMBT]?)', line)
    if tokens_s_gpu_match:
        value = tokens_s_gpu_match.group(1)
        return from_readable_format(value)
    return None

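# Illustrative: a line containing "Tokens/s/GPU: 12.5K" yields 12500.0; lines
# without the pattern return None.
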
def process_file(filepath):
    """Average the Tokens/s/GPU values reported by rank 0 in a .out log file."""
    tokens_s_gpu_values = []
    with open(filepath, 'r') as f:
        for line in f:
            if '[default0]:[rank 0]' in line:
                tokens_s_gpu = parse_log_line(line)
                if tokens_s_gpu is not None:
                    tokens_s_gpu_values.append(tokens_s_gpu)

    return int(round(np.mean(tokens_s_gpu_values))) if tokens_s_gpu_values else None

def write_csv(data, output_filepath):
    if not data:
        return

    fieldnames = ['run_name', 'status', 'dp', 'tp', 'pp', 'micro_batch_size', 'grad_acc', 'seq_len', 'avg_tokens_s_gpu']
    with open(output_filepath, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerow(data)

def read_status(status_file):
    try:
        with open(status_file, 'r') as f:
            return f.read().strip()
    except OSError:
        return None

def create_subdirectory_metrics(input_folder):
    """Create metrics.csv files in each subdirectory"""
    pattern = os.path.join(input_folder, '**/*.out')
    out_files = glob.glob(pattern, recursive=True)

    print(f"Found {len(out_files)} .out files")

    processed_dirs = []
    for file_path in out_files:
        dir_path = os.path.dirname(file_path)
        dir_name = os.path.basename(dir_path)
        output_csv = os.path.join(dir_path, 'metrics.csv')

        params = parse_folder_name(dir_name)
        avg_tokens_s_gpu = process_file(file_path)
        status = read_status(os.path.join(dir_path, 'status.txt'))

        params['run_name'] = dir_name
        if status is not None:
            params['status'] = status
        if avg_tokens_s_gpu is not None:
            params['avg_tokens_s_gpu'] = avg_tokens_s_gpu

        # Write once, with whichever optional fields are available
        write_csv(params, output_csv)

        processed_dirs.append(dir_path)
        print(f"Processed {file_path} -> Created metrics.csv")

    return processed_dirs

def aggregate_metrics(input_folder):
    """Create global_metrics.csv from all subdirectory metrics"""
    top_level_dirs = glob.glob(input_folder + '/*')

    for top_dir_path in top_level_dirs:
        subdirs = glob.glob(top_dir_path + '/*')

        aggregated_data = []

        for subdir_path in subdirs:
            metrics_file = os.path.join(subdir_path, 'metrics.csv')
            status_file = os.path.join(subdir_path, 'status.txt')

            folder_name = os.path.basename(subdir_path)

            data = {
                'run_name': folder_name,
                'status': read_status(status_file),
                **parse_folder_name(folder_name)  # Unpack the parsed parameters
            }

            # If metrics.csv exists, read the avg_tokens_s_gpu from it
            if os.path.exists(metrics_file):
                try:
                    with open(metrics_file, 'r') as f:
                        reader = csv.DictReader(f)
                        metrics_data = next(reader)
                        data['avg_tokens_s_gpu'] = int(metrics_data['avg_tokens_s_gpu'])
                except (StopIteration, KeyError, ValueError):
                    data['avg_tokens_s_gpu'] = -1
            else:
                data['avg_tokens_s_gpu'] = -1

            aggregated_data.append(data)

        # Write global metrics file
        output_file = os.path.join(top_dir_path, 'global_metrics.csv')
        fieldnames = ['run_name', 'status', 'dp', 'tp', 'pp', 'micro_batch_size',
                      'grad_acc', 'seq_len', 'avg_tokens_s_gpu']

        with open(output_file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(aggregated_data)

        print(f"Created global_metrics.csv with {len(aggregated_data)} entries")

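# Expected layout (inferred from the globbing above):
#   <input_folder>/<top_dir>/<run_dir>/{*.out, status.txt}
# Each run_dir gets a metrics.csv; each top_dir gets a global_metrics.csv.
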
def main():
    parser = argparse.ArgumentParser(description='Process log files and create metrics CSVs')
    parser.add_argument('input_folder', help='Path to the top-level folder containing experiment subfolders')
    args = parser.parse_args()

    # Step 1: Create metrics.csv in each subdirectory
    print("Creating individual metrics.csv files...")
    create_subdirectory_metrics(args.input_folder)

    # Step 2: Create global_metrics.csv
    print("\nAggregating metrics into global_metrics.csv...")
    aggregate_metrics(args.input_folder)


if __name__ == "__main__":
    main()
164 picotron/checkpoint.py Normal file
@@ -0,0 +1,164 @@
import os
import re
import torch
import torch.nn as nn
import torch.distributed as dist
from safetensors import safe_open
import contextlib

from picotron.utils import assert_no_meta_tensors
import picotron.process_group_manager as pgm


@contextlib.contextmanager
def init_model_with_dematerialized_weights(include_buffers: bool = False):
    """
    From Accelerate library: https://github.com/huggingface/accelerate/blob/v0.11.0/src/accelerate/big_modeling.py#L254
    Context manager that initializes models with empty weights (no memory allocation).

    Args:
        include_buffers (bool): Whether to also skip buffer initialization.
    """
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(module._parameters[name].to(torch.device("meta")), **kwargs)

    def register_empty_buffer(module, name, buffer):
        old_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer

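# Usage sketch (mirrors the call site in train.py): parameters registered inside
# the context are placed on torch.device("meta"), so building the module allocates
# no real storage:
#
#   with init_model_with_dematerialized_weights():
#       model = Llama(config=model_config)  # parameters live on the meta device
#   # ...then materialize real weights afterwards, e.g. with
#   # initialize_model_with_materialized_weights(model, model_config, checkpoint_path=...)
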
def initialize_model_with_materialized_weights(model, model_config, checkpoint_path, initialize_weight_tensor_func=None):
    """Initialize model with correct tensor shapes but random weights"""

    initialization_manager = InitializationManager(model, model_config)

    # Convert layer distribution ids to layer names (using the same naming convention as in safetensors)
    model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format()
    print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}")

    safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
    with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
        safetensors_names = f.keys()

        if len(safetensors_names) > len(model_layer_name_sft_format):
            print(f"Warning: Checkpoint has {len(safetensors_names)} layers but model only has {len(model_layer_name_sft_format)} layers.")

        # Create state dict with random tensors
        state_dict = {}
        for sft_name in model_layer_name_sft_format:
            # if is_tensor_belongs_to_current_pp_rank(sft_name, model_layer_name_sft_format):
            hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
            tensor = f.get_tensor(sft_name)
            tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)

            #TODO: initialize_weight_tensor_func
            #TODO: is layernorm init the same way as q k v ?
            state_dict[hf_name] = torch.randn_like(tensor)

    #TODO: Handle Tensor Parallel splitting if needed

    dist.barrier()
    model.load_state_dict(state_dict, strict=True, assign=True)
    dist.barrier()
    assert_no_meta_tensors(model)
    return model

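# Note: load_state_dict(..., assign=True) swaps the meta-device parameters for the
# newly created CPU tensors rather than copying into them, which is what allows
# loading into a model built under init_model_with_dematerialized_weights()
# (the assign= keyword requires PyTorch >= 2.1).
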
class InitializationManager:
    def __init__(self, model, model_config):
        self.model = model
        self.model_config = model_config

    def get_layer_names_in_sft_format(self):
        """Get layer names in safetensors format based on model's layer distribution."""
        decoder_components = [
            "input_layernorm",
            "mlp.down_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "post_attention_layernorm",
            "self_attn.k_proj",
            "self_attn.o_proj",
            "self_attn.q_proj",
            "self_attn.v_proj",
        ]

        # Generate base layer names
        layer_names = []
        base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
        for layer in base_names:
            layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)

        # Add special layers based on pipeline stage
        if pgm.process_group_manager.pp_is_first_stage:
            layer_names.insert(0, "model.embed_tokens.weight")
        elif pgm.process_group_manager.pp_is_last_stage:
            layer_names.extend(["model.norm.weight", "lm_head.weight"])

        return layer_names

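    # Illustrative: a first pipeline stage owning layers [0, 1] would get
    #   ["model.embed_tokens.weight",
    #    "model.layers.0.input_layernorm.weight", ..., "model.layers.1.self_attn.v_proj.weight"]
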
    def adjust_tensor_size(self, tensor, name):
        """Resize tensor based on architecture changes."""
        if 'attention' not in name:
            return tensor

        hidden_size = self.model_config.hidden_size
        head_dim = hidden_size // self.model_config.num_attention_heads

        if 'q_proj.weight' in name:
            target_dim = self.model_config.num_attention_heads * head_dim
        elif 'k_proj.weight' in name or 'v_proj.weight' in name:
            target_dim = self.model_config.num_key_value_heads * head_dim
        else:
            return tensor

        # Adjust tensor size if needed. The padding is uninitialized (torch.empty),
        # which is fine here: only the resulting shape matters, since the caller
        # replaces the values via torch.randn_like(tensor).
        if tensor.shape[0] != target_dim:
            if target_dim > tensor.shape[0]:
                pad_tensor = torch.empty(target_dim - tensor.shape[0], tensor.shape[1],
                                         dtype=tensor.dtype, device=tensor.device)
                tensor = torch.cat([tensor, pad_tensor], dim=0)
            else:
                tensor = tensor[:target_dim, :]

        return tensor

    def convert_safetensors_to_hf_name(self, sft_name):
        """Convert safetensors naming convention to HuggingFace naming convention."""
        name_mapping = {
            "model.": "",
            "layers.": "decoder_layers.",
            "embed_tokens": "embedding",
            "self_attn.": "attention.",
            "o_proj": "out_proj",
            "lm_head": "final_proj",
            "input_layernorm": "input_layernorm",
            "post_attention_layernorm": "post_attention_layernorm",
            r'^norm': 'final_norm'
        }

        result = sft_name
        for pattern, replacement in name_mapping.items():
            result = re.sub(pattern, replacement, result)
        return result

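    # Illustrative: "model.layers.0.self_attn.q_proj.weight"
    #   -> "decoder_layers.0.attention.q_proj.weight",
    # and "model.norm.weight" -> "final_norm.weight" (the '^norm' rule fires once
    # the "model." prefix has been stripped).
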

#TODO: Implement and Move save/load checkpoint here
# class CheckpointManager:
#     pass
picotron/pipeline_parallel/pipeline_parallel.py
@@ -8,9 +8,9 @@ from picotron.pipeline_parallel.pp_communications import pipeline_communicate, b
class PipelineParallel(nn.Module):
    def __init__(self, model, config):
        super().__init__()
-       layer_distribution = self.distribute_layers(config.num_hidden_layers)
+       self.layer_distribution = self.distribute_layers(config.num_hidden_layers)
        self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
-       self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in layer_distribution})
+       self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in self.layer_distribution})
        self.final_norm = model.final_norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
        self.final_proj = model.final_proj if pgm.process_group_manager.pp_is_last_stage else nn.Identity()

picotron/utils.py
@@ -4,6 +4,8 @@ import random
import numpy as np
import builtins
import fcntl
import json
import torch.nn as nn
import picotron.process_group_manager as pgm

def print(*args, is_print_rank=True, **kwargs):
@@ -32,6 +34,18 @@ def to_readable_format(num, precision=2):
        return f"{num / 1e3:.{precision}f}K"
    else:
        return f"{num:.{precision}f}"

def assert_no_meta_tensors(model):
    meta_tensors = []
    for name, param in model.named_parameters():
        if param.device == torch.device("meta"):
            meta_tensors.append(f"Parameter '{name}' with shape {param.shape}")

    for name, buffer in model.named_buffers():
        if buffer.device == torch.device("meta"):
            meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")

    assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)

def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
    """Save the model/optimizer states/steps to a checkpoint file."""
@@ -67,4 +81,4 @@ def load_checkpoint(model, optimizer, out_dir):
    raw_model.load_state_dict(checkpoint['model'])
    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer'])
-   return checkpoint['trained_steps'], checkpoint['trained_tokens']
\ No newline at end of file
+   return checkpoint['trained_steps'], checkpoint['trained_tokens']

25 train.py
@@ -23,12 +23,15 @@ from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel import DataParallelBucket
from picotron.model import Llama
import wandb
import lovely_tensors as lt; lt.monkey_patch()


def all_reduce_loss_across_dp_cp_ranks(loss, device):
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
@@ -66,7 +69,7 @@ def train_step(model, data_loader, device):

    return acc_loss

-if __name__ == "__main__":
+if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="", help="Path to config file")
    args = parser.parse_args()
@@ -173,20 +176,22 @@ if __name__ == "__main__":
    model_config.max_position_embeddings = SEQ_LEN

-   start_time = time.time()
-   model = Llama(config=model_config)
-   print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
-   dist.barrier()
+   start_time = time.time()
+   with init_model_with_dematerialized_weights():
+       model = Llama(config=model_config)
+
+       if pgm.process_group_manager.tp_world_size > 1:
+           #TODO: remove the initialize_weight_tensor and do it at initialize_model_with_materialized_weights() level
+           model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
+
+       if pgm.process_group_manager.pp_world_size > 1:
+           model = PipelineParallel(model, model_config)
+
+   model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/TinyLlama-1.1B-Chat-v0.1", initialize_weight_tensor_func=initialize_weight_tensor)
+   print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+   start_time = time.time()

    if pgm.process_group_manager.cp_world_size > 1:
        model = apply_context_parallel(model)

-   if pgm.process_group_manager.pp_world_size > 1:
-       model = PipelineParallel(model, model_config)
-
    model.to(dtype).to(device)

@@ -266,4 +271,4 @@ if __name__ == "__main__":
    if is_wandb_rank and USE_WANDB:
        wandb.finish()

-   dist.destroy_process_group()
\ No newline at end of file
+   dist.destroy_process_group()