diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py index 18d6229..9b429c8 100644 --- a/picotron/checkpoint.py +++ b/picotron/checkpoint.py @@ -45,7 +45,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False): if include_buffers: nn.Module.register_buffer = old_register_buffer -def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path): +def init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path): #Initialize model with correct tensor shapes but random weights initialization_manager = InitializationManager(model, model_config) layer_names = initialization_manager.get_layer_names_in_sft_format() diff --git a/train.py b/train.py index d0bc591..aa1e182 100644 --- a/train.py +++ b/train.py @@ -24,7 +24,7 @@ from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel import picotron.process_group_manager as pgm from picotron.utils import set_all_seed, print, to_readable_format from picotron.checkpoint import CheckpointManager -from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights +from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights from picotron.data import MicroBatchDataLoader from picotron.process_group_manager import setup_process_group_manager from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel @@ -180,8 +180,7 @@ if __name__ == "__main__": if pgm.process_group_manager.pp_world_size > 1: model = PipelineParallel(model, model_config) - #TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ? - model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"]) + model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"]) if pgm.process_group_manager.cp_world_size > 1: model = apply_context_parallel(model)