refactor checkpoint

ferdinand.mom 2024-12-01 03:40:56 +00:00
parent 5045be87e0
commit daea1fed3f
3 changed files with 54 additions and 47 deletions

View File

@@ -159,6 +159,50 @@ class InitializationManager:
         result = re.sub(pattern, replacement, result)
         return result
 
-#TODO: Implement and Move save/load checkpoint here
-# class CheckpointManager:
-#     pass
+class CheckpointManager:
+    def __init__(self):
+        self.tp_rank = pgm.process_group_manager.tp_rank
+        self.pp_rank = pgm.process_group_manager.pp_rank
+        self.tp_world_size = pgm.process_group_manager.tp_world_size
+        self.pp_world_size = pgm.process_group_manager.pp_world_size
+        self.cp_dp_world_size = pgm.process_group_manager.cp_dp_world_size
+        self.dp_rank = pgm.process_group_manager.dp_rank
+        self.cp_rank = pgm.process_group_manager.cp_rank
+
+    def _get_checkpoint_path(self, out_dir):
+        ckpt_name = f"weights_tp_rank_world_size={self.tp_rank}_{self.tp_world_size}_pp_rank_world_size={self.pp_rank}_{self.pp_world_size}.pth"
+        return os.path.join(out_dir, ckpt_name)
+
+    def save_checkpoint(self, model, optimizer, trained_steps, trained_tokens, out_dir):
+        """Save the model/optimizer states/steps to a checkpoint file."""
+        path = self._get_checkpoint_path(out_dir)
+
+        # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
+        if self.dp_rank == 0 and self.cp_rank == 0:
+            os.makedirs(out_dir, exist_ok=True)
+            raw_model = model.module if self.cp_dp_world_size > 1 else model
+            checkpoint = {
+                'model': raw_model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'trained_steps': trained_steps,
+                'trained_tokens': trained_tokens
+            }
+            torch.save(checkpoint, path)
+
+    def load_checkpoint(self, model, optimizer, out_dir):
+        """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
+        path = self._get_checkpoint_path(out_dir)
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"Checkpoint not found at {path}")
+
+        checkpoint = torch.load(path)
+
+        # Load model weights
+        raw_model = model.module if self.cp_dp_world_size > 1 else model
+        raw_model.load_state_dict(checkpoint['model'])
+
+        # Load optimizer state
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+        return checkpoint['trained_steps'], checkpoint['trained_tokens']
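
To make the sharding scheme concrete, here is a minimal standalone sketch (hypothetical ranks and output directory, not part of the commit) of the path that _get_checkpoint_path produces: one file per (tp_rank, pp_rank) pair, written only by DP/CP rank 0.

    import os

    # Hypothetical topology: 2-way tensor parallel, 2-way pipeline parallel.
    tp_rank, tp_world_size = 1, 2
    pp_rank, pp_world_size = 0, 2
    out_dir = "checkpoints/1000"  # hypothetical, e.g. CHECKPOINT_DIR + f"/{step}" as in train.py below

    # Same f-string as CheckpointManager._get_checkpoint_path above.
    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
    print(os.path.join(out_dir, ckpt_name))
    # -> checkpoints/1000/weights_tp_rank_world_size=1_2_pp_rank_world_size=0_2.pth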

View File

@@ -1,12 +1,8 @@
-import os
 import torch
 import random
 import numpy as np
 import builtins
 import fcntl
-import json
-import torch.nn as nn
-import picotron.process_group_manager as pgm
 
 def print(*args, is_print_rank=True, **kwargs):
     """ solves multi-process interleaved print problem """
@@ -45,40 +41,4 @@ def assert_no_meta_tensors(model):
         if buffer.device == torch.device("meta"):
             meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")
     assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)
-
-def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
-    """Save the model/optimizer states/steps to a checkpoint file."""
-    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
-    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
-    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
-    path = os.path.join(out_dir, ckpt_name)
-
-    # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
-    if pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0:
-        os.makedirs(out_dir, exist_ok=True)
-        raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
-        checkpoint = {
-            'model': raw_model.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'trained_steps': trained_steps,
-            'trained_tokens': trained_tokens
-        }
-        torch.save(checkpoint, path)
-
-def load_checkpoint(model, optimizer, out_dir):
-    """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
-    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
-    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
-    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
-    path = os.path.join(out_dir, ckpt_name)
-
-    if not os.path.exists(path):
-        raise FileNotFoundError(f"Checkpoint not found at {path}")
-
-    checkpoint = torch.load(path)
-
-    # Load model weights
-    raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
-    raw_model.load_state_dict(checkpoint['model'])
-
-    # Load optimizer state
-    optimizer.load_state_dict(checkpoint['optimizer'])
-
-    return checkpoint['trained_steps'], checkpoint['trained_tokens']

View File

@@ -22,7 +22,8 @@ from transformers import AutoConfig
 from picotron.context_parallel.context_parallel import apply_context_parallel
 from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
 import picotron.process_group_manager as pgm
-from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
+from picotron.utils import set_all_seed, print, to_readable_format
+from picotron.checkpoint import CheckpointManager
 from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
 from picotron.data import MicroBatchDataLoader
 from picotron.process_group_manager import setup_process_group_manager
@@ -213,10 +214,12 @@ if __name__ == "__main__":
     extra_args = dict(fused=True) if use_fused else dict()
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
 
+    checkpoint_manager = CheckpointManager()
+
     trained_tokens, step = 0, 0
     if LOAD_PATH:
-        step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
+        step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH)
         dist.barrier()
 
     #TODO: Add activation checkpointing
@@ -263,7 +266,7 @@ if __name__ == "__main__":
                 "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
 
         if step % CHECKPOINT_FREQ == 0:
-            save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
+            checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
 
         if step >= TOTAL_TRAIN_STEPS:
             break