refactor checkpoint

This commit is contained in:
ferdinand.mom 2024-12-01 03:40:56 +00:00
parent 5045be87e0
commit daea1fed3f
3 changed files with 54 additions and 47 deletions

View File

@ -159,6 +159,50 @@ class InitializationManager:
result = re.sub(pattern, replacement, result)
return result
#TODO: Implement and Move save/load checkpoint here
# class CheckpointManager:
# pass
class CheckpointManager:
def __init__(self):
self.tp_rank = pgm.process_group_manager.tp_rank
self.pp_rank = pgm.process_group_manager.pp_rank
self.tp_world_size = pgm.process_group_manager.tp_world_size
self.pp_world_size = pgm.process_group_manager.pp_world_size
self.cp_dp_world_size = pgm.process_group_manager.cp_dp_world_size
self.dp_rank = pgm.process_group_manager.dp_rank
self.cp_rank = pgm.process_group_manager.cp_rank
def _get_checkpoint_path(self, out_dir):
ckpt_name = f"weights_tp_rank_world_size={self.tp_rank}_{self.tp_world_size}_pp_rank_world_size={self.pp_rank}_{self.pp_world_size}.pth"
return os.path.join(out_dir, ckpt_name)
def save_checkpoint(self, model, optimizer, trained_steps, trained_tokens, out_dir):
"""Save the model/optimizer states/steps to a checkpoint file."""
path = self._get_checkpoint_path(out_dir)
# Only DP/CP rank 0 will save the model, the weights are the same across all ranks
if self.dp_rank == 0 and self.cp_rank == 0:
os.makedirs(out_dir, exist_ok=True)
raw_model = model.module if self.cp_dp_world_size > 1 else model
checkpoint = {
'model': raw_model.state_dict(),
'optimizer': optimizer.state_dict(),
'trained_steps': trained_steps,
'trained_tokens': trained_tokens
}
torch.save(checkpoint, path)
def load_checkpoint(self, model, optimizer, out_dir):
"""Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
path = self._get_checkpoint_path(out_dir)
if not os.path.exists(path):
raise FileNotFoundError(f"Checkpoint not found at {path}")
checkpoint = torch.load(path)
# Load model weights
raw_model = model.module if self.cp_dp_world_size > 1 else model
raw_model.load_state_dict(checkpoint['model'])
# Load optimizer state
optimizer.load_state_dict(checkpoint['optimizer'])
return checkpoint['trained_steps'], checkpoint['trained_tokens']

View File

@ -1,12 +1,8 @@
import os
import torch
import random
import numpy as np
import builtins
import fcntl
import json
import torch.nn as nn
import picotron.process_group_manager as pgm
def print(*args, is_print_rank=True, **kwargs):
""" solves multi-process interleaved print problem """
@ -45,40 +41,4 @@ def assert_no_meta_tensors(model):
if buffer.device == torch.device("meta"):
meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")
assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)
def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
"""Save the model/optimizer states/steps to a checkpoint file."""
tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
path = os.path.join(out_dir, ckpt_name)
# Only DP/CP rank 0 will save the model, the weights are the same across all ranks
if pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0:
os.makedirs(out_dir, exist_ok=True)
raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
checkpoint = {
'model': raw_model.state_dict(),
'optimizer': optimizer.state_dict(),
'trained_steps': trained_steps,
'trained_tokens': trained_tokens
}
torch.save(checkpoint, path)
def load_checkpoint(model, optimizer, out_dir):
"""Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
path = os.path.join(out_dir, ckpt_name)
if not os.path.exists(path):
raise FileNotFoundError(f"Checkpoint not found at {path}")
checkpoint = torch.load(path)
# Load model weights
raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
raw_model.load_state_dict(checkpoint['model'])
# Load optimizer state
optimizer.load_state_dict(checkpoint['optimizer'])
return checkpoint['trained_steps'], checkpoint['trained_tokens']
assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)

View File

@ -22,7 +22,8 @@ from transformers import AutoConfig
from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from picotron.utils import set_all_seed, print, to_readable_format
from picotron.checkpoint import CheckpointManager
from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
@ -213,10 +214,12 @@ if __name__ == "__main__":
extra_args = dict(fused=True) if use_fused else dict()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
checkpoint_manager = CheckpointManager()
trained_tokens, step = 0, 0
if LOAD_PATH:
step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH)
dist.barrier()
#TODO: Add activation checkpointing
@ -263,7 +266,7 @@ if __name__ == "__main__":
"memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
if step % CHECKPOINT_FREQ == 0:
save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
if step >= TOTAL_TRAIN_STEPS:
break