refactor checkpoint

commit daea1fed3f
parent 5045be87e0
picotron/checkpoint.py
@@ -159,6 +159,50 @@ class InitializationManager:
             result = re.sub(pattern, replacement, result)
         return result
 
-#TODO: Implement and Move save/load checkpoint here
-# class CheckpointManager:
-# pass
+class CheckpointManager:
+    def __init__(self):
+        self.tp_rank = pgm.process_group_manager.tp_rank
+        self.pp_rank = pgm.process_group_manager.pp_rank
+        self.tp_world_size = pgm.process_group_manager.tp_world_size
+        self.pp_world_size = pgm.process_group_manager.pp_world_size
+        self.cp_dp_world_size = pgm.process_group_manager.cp_dp_world_size
+        self.dp_rank = pgm.process_group_manager.dp_rank
+        self.cp_rank = pgm.process_group_manager.cp_rank
+
+    def _get_checkpoint_path(self, out_dir):
+        ckpt_name = f"weights_tp_rank_world_size={self.tp_rank}_{self.tp_world_size}_pp_rank_world_size={self.pp_rank}_{self.pp_world_size}.pth"
+        return os.path.join(out_dir, ckpt_name)
+
+    def save_checkpoint(self, model, optimizer, trained_steps, trained_tokens, out_dir):
+        """Save the model/optimizer states/steps to a checkpoint file."""
+        path = self._get_checkpoint_path(out_dir)
+
+        # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
+        if self.dp_rank == 0 and self.cp_rank == 0:
+            os.makedirs(out_dir, exist_ok=True)
+            raw_model = model.module if self.cp_dp_world_size > 1 else model
+            checkpoint = {
+                'model': raw_model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'trained_steps': trained_steps,
+                'trained_tokens': trained_tokens
+            }
+            torch.save(checkpoint, path)
+
+    def load_checkpoint(self, model, optimizer, out_dir):
+        """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
+        path = self._get_checkpoint_path(out_dir)
+
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"Checkpoint not found at {path}")
+
+        checkpoint = torch.load(path)
+
+        # Load model weights
+        raw_model = model.module if self.cp_dp_world_size > 1 else model
+        raw_model.load_state_dict(checkpoint['model'])
+
+        # Load optimizer state
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+        return checkpoint['trained_steps'], checkpoint['trained_tokens']
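Note on the naming scheme in _get_checkpoint_path: each (TP, PP) shard writes its own weights file, and all DP/CP ranks map to the same file since they hold identical weights. A minimal standalone sketch (the helper name checkpoint_path and the example values are illustrative, not part of the commit):

    import os

    def checkpoint_path(out_dir, tp_rank, tp_world_size, pp_rank, pp_world_size):
        # Mirrors CheckpointManager._get_checkpoint_path: one file per (TP, PP) shard,
        # shared by all DP/CP ranks.
        ckpt_name = (
            f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}"
            f"_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
        )
        return os.path.join(out_dir, ckpt_name)

    # e.g. TP rank 1 of 2, PP rank 2 of 4, checkpointing at step 1000:
    print(checkpoint_path("checkpoints/1000", 1, 2, 2, 4))
    # checkpoints/1000/weights_tp_rank_world_size=1_2_pp_rank_world_size=2_4.pth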
picotron/utils.py
@@ -1,12 +1,8 @@
-import os
 import torch
 import random
 import numpy as np
 import builtins
 import fcntl
-import json
-import torch.nn as nn
-import picotron.process_group_manager as pgm
 
 def print(*args, is_print_rank=True, **kwargs):
     """ solves multi-process interleaved print problem """
@@ -45,40 +41,4 @@ def assert_no_meta_tensors(model):
         if buffer.device == torch.device("meta"):
             meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")
 
     assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)
-
-def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
-    """Save the model/optimizer states/steps to a checkpoint file."""
-    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
-    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
-    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
-    path = os.path.join(out_dir, ckpt_name)
-
-    # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
-    if pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0:
-        os.makedirs(out_dir, exist_ok=True)
-        raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
-        checkpoint = {
-            'model': raw_model.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'trained_steps': trained_steps,
-            'trained_tokens': trained_tokens
-        }
-        torch.save(checkpoint, path)
-
-def load_checkpoint(model, optimizer, out_dir):
-    """Load the model/optimizer states from the latest checkpoint. Assume the topology is the same."""
-    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
-    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
-    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
-    path = os.path.join(out_dir, ckpt_name)
-    if not os.path.exists(path):
-        raise FileNotFoundError(f"Checkpoint not found at {path}")
-    checkpoint = torch.load(path)
-
-    # Load model weights
-    raw_model = model.module if pgm.process_group_manager.cp_dp_world_size > 1 else model
-    raw_model.load_state_dict(checkpoint['model'])
-    # Load optimizer state
-    optimizer.load_state_dict(checkpoint['optimizer'])
-    return checkpoint['trained_steps'], checkpoint['trained_tokens']
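Note: the checkpoint payload itself is unchanged by the refactor: a dict holding the model and optimizer state dicts plus the step/token counters. A minimal round-trip sketch with toy stand-ins (nn.Linear, the /tmp path, and the counter values are illustrative):

    import torch
    import torch.nn as nn
    from torch.optim import AdamW

    # Toy stand-ins for the sharded model/optimizer; the dict layout matches the diff above.
    model = nn.Linear(4, 4)
    optimizer = AdamW(model.parameters(), lr=1e-3)

    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'trained_steps': 100,
        'trained_tokens': 409600,
    }
    torch.save(checkpoint, "/tmp/toy_ckpt.pth")

    # Loading restores both states and returns the counters, as load_checkpoint does.
    loaded = torch.load("/tmp/toy_ckpt.pth")
    model.load_state_dict(loaded['model'])
    optimizer.load_state_dict(loaded['optimizer'])
    step, trained_tokens = loaded['trained_steps'], loaded['trained_tokens']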
train.py (9 changed lines)
@@ -22,7 +22,8 @@ from transformers import AutoConfig
 from picotron.context_parallel.context_parallel import apply_context_parallel
 from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
 import picotron.process_group_manager as pgm
-from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
+from picotron.utils import set_all_seed, print, to_readable_format
+from picotron.checkpoint import CheckpointManager
 from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
 from picotron.data import MicroBatchDataLoader
 from picotron.process_group_manager import setup_process_group_manager
@@ -213,10 +214,12 @@ if __name__ == "__main__":
     extra_args = dict(fused=True) if use_fused else dict()
 
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
 
+    checkpoint_manager = CheckpointManager()
+
     trained_tokens, step = 0, 0
     if LOAD_PATH:
-        step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
+        step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH)
 
     dist.barrier()
     #TODO: Add activation checkpointing
@@ -263,7 +266,7 @@ if __name__ == "__main__":
                 "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
 
         if step % CHECKPOINT_FREQ == 0:
-            save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
+            checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
 
         if step >= TOTAL_TRAIN_STEPS:
            break