From daea1fed3f76e8df35487d2bb06c60466b1dfdc0 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Sun, 1 Dec 2024 03:40:56 +0000
Subject: [PATCH] refactor checkpoint

---
 picotron/checkpoint.py | 50 +++++++++++++++++++++++++++++++++++++++---
 picotron/utils.py      | 42 +----------------------------------
 train.py               |  9 +++++---
 3 files changed, 54 insertions(+), 47 deletions(-)

diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py
index 13ec589..a69f87e 100644
--- a/picotron/checkpoint.py
+++ b/picotron/checkpoint.py
@@ -159,6 +159,50 @@ class InitializationManager:
             result = re.sub(pattern, replacement, result)
         return result
 
-#TODO: Implement and Move save/load checkpoint here
-# class CheckpointManager:
-#     pass
+class CheckpointManager:
+    def __init__(self):
+        self.tp_rank = pgm.process_group_manager.tp_rank
+        self.pp_rank = pgm.process_group_manager.pp_rank
+        self.tp_world_size = pgm.process_group_manager.tp_world_size
+        self.pp_world_size = pgm.process_group_manager.pp_world_size
+        self.cp_dp_world_size = pgm.process_group_manager.cp_dp_world_size
+        self.dp_rank = pgm.process_group_manager.dp_rank
+        self.cp_rank = pgm.process_group_manager.cp_rank
+
+    def _get_checkpoint_path(self, out_dir):
+        ckpt_name = f"weights_tp_rank_world_size={self.tp_rank}_{self.tp_world_size}_pp_rank_world_size={self.pp_rank}_{self.pp_world_size}.pth"
+        return os.path.join(out_dir, ckpt_name)
+
+    def save_checkpoint(self, model, optimizer, trained_steps, trained_tokens, out_dir):
+        """Save the model/optimizer states/steps to a checkpoint file."""
+        path = self._get_checkpoint_path(out_dir)
+
+        # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
+        if self.dp_rank == 0 and self.cp_rank == 0:
+            os.makedirs(out_dir, exist_ok=True)
+            raw_model = model.module if self.cp_dp_world_size > 1 else model
+            checkpoint = {
+                'model': raw_model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'trained_steps': trained_steps,
+                'trained_tokens': trained_tokens
+            }
+            torch.save(checkpoint, path)
+
+    def load_checkpoint(self, model, optimizer, out_dir):
+        """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
+        path = self._get_checkpoint_path(out_dir)
+
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"Checkpoint not found at {path}")
+
+        checkpoint = torch.load(path)
+
+        # Load model weights
+        raw_model = model.module if self.cp_dp_world_size > 1 else model
+        raw_model.load_state_dict(checkpoint['model'])
+
+        # Load optimizer state
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+        return checkpoint['trained_steps'], checkpoint['trained_tokens']
diff --git a/picotron/utils.py b/picotron/utils.py
index 30bde4f..996cdeb 100644
--- a/picotron/utils.py
+++ b/picotron/utils.py
@@ -1,12 +1,8 @@
-import os
 import torch
 import random
 import numpy as np
 import builtins
 import fcntl
-import json
-import torch.nn as nn
-import picotron.process_group_manager as pgm
 
 def print(*args, is_print_rank=True, **kwargs):
     """ solves multi-process interleaved print problem """
@@ -45,40 +41,4 @@ def assert_no_meta_tensors(model):
         if buffer.device == torch.device("meta"):
             meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")
     
-    assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)
-
-def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
-    """Save the model/optimizer states/steps to a checkpoint file."""
-    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
-    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
-    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
-    path = os.path.join(out_dir, ckpt_name)
-
-    # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
-    if pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0:
-        os.makedirs(out_dir, exist_ok=True)
-        raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
-        checkpoint = {
-            'model': raw_model.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'trained_steps': trained_steps,
-            'trained_tokens': trained_tokens
-        }
-        torch.save(checkpoint, path)
-
-def load_checkpoint(model, optimizer, out_dir):
-    """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
-    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
-    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
-    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
-    path = os.path.join(out_dir, ckpt_name)
-    if not os.path.exists(path):
-        raise FileNotFoundError(f"Checkpoint not found at {path}")
-    checkpoint = torch.load(path)
-
-    # Load model weights
-    raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
-    raw_model.load_state_dict(checkpoint['model'])
-    # Load optimizer state
-    optimizer.load_state_dict(checkpoint['optimizer'])
-    return checkpoint['trained_steps'], checkpoint['trained_tokens']
+    assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)
\ No newline at end of file
diff --git a/train.py b/train.py
index 16bab3b..72d1a64 100644
--- a/train.py
+++ b/train.py
@@ -22,7 +22,8 @@ from transformers import AutoConfig
 from picotron.context_parallel.context_parallel import apply_context_parallel
 from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
 import picotron.process_group_manager as pgm
-from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
+from picotron.utils import set_all_seed, print, to_readable_format
+from picotron.checkpoint import CheckpointManager
 from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
 from picotron.data import MicroBatchDataLoader
 from picotron.process_group_manager import setup_process_group_manager
@@ -213,10 +214,12 @@ if __name__ == "__main__":
         extra_args = dict(fused=True) if use_fused else dict()
         optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
+
+    checkpoint_manager = CheckpointManager()
 
     trained_tokens, step = 0, 0
     if LOAD_PATH:
-        step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
+        step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH)
         dist.barrier()
 
     #TODO: Add activation checkpointing
@@ -263,7 +266,7 @@
                 "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
 
         if step % CHECKPOINT_FREQ == 0:
-            save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
+            checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
 
         if step >= TOTAL_TRAIN_STEPS:
             break
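For reference, a minimal usage sketch of the refactored CheckpointManager, mirroring the train.py hunks above. It assumes the process groups have already been initialized (so pgm.process_group_manager is populated, as done by setup_process_group_manager in train.py) and that model, optimizer, LOAD_PATH, CHECKPOINT_DIR, CHECKPOINT_FREQ and TOTAL_TRAIN_STEPS are defined as in the training script; the training-loop body itself is elided.

from picotron.checkpoint import CheckpointManager

# Sketch only: the constructor reads TP/PP/DP/CP ranks and world sizes from
# pgm.process_group_manager, so the process groups must be set up first.
checkpoint_manager = CheckpointManager()

trained_tokens, step = 0, 0
if LOAD_PATH:
    # Restores model/optimizer state and resumes the step/token counters.
    step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH)

while step < TOTAL_TRAIN_STEPS:
    # ... forward/backward/optimizer step; update step and trained_tokens ...
    step += 1
    if step % CHECKPOINT_FREQ == 0:
        # Only DP/CP rank 0 writes the file; the filename encodes TP/PP rank and world size.
        checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR + f"/{step}")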