more consistent naming

This commit is contained in:
ferdinand.mom 2024-12-01 19:45:11 +00:00
parent 32d8daa880
commit 804f43c97e
2 changed files with 3 additions and 4 deletions

View File

@ -45,7 +45,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
if include_buffers:
nn.Module.register_buffer = old_register_buffer
def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
def init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
#Initialize model with correct tensor shapes but random weights
initialization_manager = InitializationManager(model, model_config)
layer_names = initialization_manager.get_layer_names_in_sft_format()

View File

@ -24,7 +24,7 @@ from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format
from picotron.checkpoint import CheckpointManager
from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
@ -180,8 +180,7 @@ if __name__ == "__main__":
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
#TODO: don't hardcode the path of checkpoint_path. Maybe rename it to "safetensor_path"?
model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
if pgm.process_group_manager.cp_world_size > 1:
model = apply_context_parallel(model)