refactor checkpoint
parent 5045be87e0
commit daea1fed3f
@@ -159,6 +159,50 @@ class InitializationManager:
        result = re.sub(pattern, replacement, result)
        return result

#TODO: Implement and Move save/load checkpoint here
# class CheckpointManager:
#     pass
class CheckpointManager:
    def __init__(self):
        self.tp_rank = pgm.process_group_manager.tp_rank
        self.pp_rank = pgm.process_group_manager.pp_rank
        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.pp_world_size = pgm.process_group_manager.pp_world_size
        self.cp_dp_world_size = pgm.process_group_manager.cp_dp_world_size
        self.dp_rank = pgm.process_group_manager.dp_rank
        self.cp_rank = pgm.process_group_manager.cp_rank

    def _get_checkpoint_path(self, out_dir):
        ckpt_name = f"weights_tp_rank_world_size={self.tp_rank}_{self.tp_world_size}_pp_rank_world_size={self.pp_rank}_{self.pp_world_size}.pth"
        return os.path.join(out_dir, ckpt_name)

    def save_checkpoint(self, model, optimizer, trained_steps, trained_tokens, out_dir):
        """Save the model/optimizer states/steps to a checkpoint file."""
        path = self._get_checkpoint_path(out_dir)

        # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
        if self.dp_rank == 0 and self.cp_rank == 0:
            os.makedirs(out_dir, exist_ok=True)
            raw_model = model.module if self.cp_dp_world_size > 1 else model
            checkpoint = {
                'model': raw_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'trained_steps': trained_steps,
                'trained_tokens': trained_tokens
            }
            torch.save(checkpoint, path)

    def load_checkpoint(self, model, optimizer, out_dir):
        """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
        path = self._get_checkpoint_path(out_dir)

        if not os.path.exists(path):
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        checkpoint = torch.load(path)

        # Load model weights
        raw_model = model.module if self.cp_dp_world_size > 1 else model
        raw_model.load_state_dict(checkpoint['model'])

        # Load optimizer state
        optimizer.load_state_dict(checkpoint['optimizer'])

        return checkpoint['trained_steps'], checkpoint['trained_tokens']
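For context, a minimal single-process sketch (not part of this commit) of how the new CheckpointManager round-trips a model and optimizer. The SimpleNamespace stub standing in for pgm.process_group_manager and the tiny nn.Linear model are assumptions made purely for illustration; in a real run the ranks come from setup_process_group_manager.

# Hypothetical single-process sketch: stub the process group manager so
# CheckpointManager sees a trivial topology (all ranks 0, all world sizes 1).
from types import SimpleNamespace

import torch
import torch.nn as nn

import picotron.process_group_manager as pgm
from picotron.checkpoint import CheckpointManager

pgm.process_group_manager = SimpleNamespace(
    tp_rank=0, pp_rank=0, dp_rank=0, cp_rank=0,
    tp_world_size=1, pp_world_size=1, cp_dp_world_size=1,
)

model = nn.Linear(4, 4)                          # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

manager = CheckpointManager()
manager.save_checkpoint(model, optimizer, trained_steps=10, trained_tokens=1_000, out_dir="checkpoints/10")
step, tokens = manager.load_checkpoint(model, optimizer, out_dir="checkpoints/10")
assert (step, tokens) == (10, 1_000)             # training state restored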
@@ -1,12 +1,8 @@
import os
import torch
import random
import numpy as np
import builtins
import fcntl
import json
import torch.nn as nn
import picotron.process_group_manager as pgm

def print(*args, is_print_rank=True, **kwargs):
    """ solves multi-process interleaved print problem """
@@ -45,40 +41,4 @@ def assert_no_meta_tensors(model):
        if buffer.device == torch.device("meta"):
            meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")

    assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)

def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
    """Save the model/optimizer states/steps to a checkpoint file."""
    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
    path = os.path.join(out_dir, ckpt_name)

    # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
    if pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0:
        os.makedirs(out_dir, exist_ok=True)
        raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'trained_steps': trained_steps,
            'trained_tokens': trained_tokens
        }
        torch.save(checkpoint, path)

def load_checkpoint(model, optimizer, out_dir):
    """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
    path = os.path.join(out_dir, ckpt_name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found at {path}")
    checkpoint = torch.load(path)

    # Load model weights
    raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
    raw_model.load_state_dict(checkpoint['model'])
    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['trained_steps'], checkpoint['trained_tokens']
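Worth noting: the refactor keeps the on-disk layout unchanged. The removed free functions above and CheckpointManager._get_checkpoint_path build the same per-(TP, PP) filename, so checkpoints written by the old helpers should remain loadable. A standalone sketch of that naming scheme, with hypothetical ranks and world sizes:

import os

def checkpoint_path(out_dir, tp_rank, tp_world_size, pp_rank, pp_world_size):
    # Same f-string as both the old helpers and the new class: one weights file
    # per (TP, PP) coordinate; DP/CP replicas share it since only DP/CP rank 0 writes.
    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
    return os.path.join(out_dir, ckpt_name)

# Hypothetical topology: TP=2, PP=4, viewed from TP rank 1, PP rank 3.
print(checkpoint_path("checkpoints/100", 1, 2, 3, 4))
# checkpoints/100/weights_tp_rank_world_size=1_2_pp_rank_world_size=3_4.pth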
train.py
@@ -22,7 +22,8 @@ from transformers import AutoConfig
from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from picotron.utils import set_all_seed, print, to_readable_format
from picotron.checkpoint import CheckpointManager
from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
@@ -213,10 +214,12 @@ if __name__ == "__main__":
    extra_args = dict(fused=True) if use_fused else dict()

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)

    checkpoint_manager = CheckpointManager()

    trained_tokens, step = 0, 0
    if LOAD_PATH:
        step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
        step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH)

    dist.barrier()
    #TODO: Add activation checkpointing
@@ -263,7 +266,7 @@
                "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})

        if step % CHECKPOINT_FREQ == 0:
            save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
            checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")

        if step >= TOTAL_TRAIN_STEPS:
            break
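As a usage note with hypothetical values (the real ones come from train.py's configuration): the save call above writes each snapshot into its own per-step subdirectory of CHECKPOINT_DIR, and LOAD_PATH is expected to point at one of those subdirectories when resuming.

# Hypothetical config values, for illustration only.
CHECKPOINT_DIR, CHECKPOINT_FREQ, TOTAL_TRAIN_STEPS = "checkpoints", 500, 2000

for step in range(1, TOTAL_TRAIN_STEPS + 1):
    if step % CHECKPOINT_FREQ == 0:
        out_dir = CHECKPOINT_DIR + f"/{step}"   # same expression as the save call above
        print(out_dir)                          # checkpoints/500, checkpoints/1000, ...

# Resuming later: set LOAD_PATH to one of these directories, e.g. "checkpoints/1000";
# checkpoint_manager.load_checkpoint then restores the model/optimizer state and
# returns (trained_steps, trained_tokens).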