From 804f43c97e46e14018b8d4816511974cde7df858 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Sun, 1 Dec 2024 19:45:11 +0000 Subject: [PATCH] more consistent naming --- picotron/checkpoint.py | 2 +- train.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py index 18d6229..9b429c8 100644 --- a/picotron/checkpoint.py +++ b/picotron/checkpoint.py @@ -45,7 +45,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False): if include_buffers: nn.Module.register_buffer = old_register_buffer -def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path): +def init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path): #Initialize model with correct tensor shapes but random weights initialization_manager = InitializationManager(model, model_config) layer_names = initialization_manager.get_layer_names_in_sft_format() diff --git a/train.py b/train.py index d0bc591..aa1e182 100644 --- a/train.py +++ b/train.py @@ -24,7 +24,7 @@ from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel import picotron.process_group_manager as pgm from picotron.utils import set_all_seed, print, to_readable_format from picotron.checkpoint import CheckpointManager -from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights +from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights from picotron.data import MicroBatchDataLoader from picotron.process_group_manager import setup_process_group_manager from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel @@ -180,8 +180,7 @@ if __name__ == "__main__": if pgm.process_group_manager.pp_world_size > 1: model = PipelineParallel(model, model_config) - #TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ? - model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"]) + model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"]) if pgm.process_group_manager.cp_world_size > 1: model = apply_context_parallel(model)