more consistent naming
This commit is contained in:
parent
32d8daa880
commit
804f43c97e
@ -45,7 +45,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
|
||||
if include_buffers:
|
||||
nn.Module.register_buffer = old_register_buffer
|
||||
|
||||
def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
|
||||
def init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
|
||||
#Initialize model with correct tensor shapes but random weights
|
||||
initialization_manager = InitializationManager(model, model_config)
|
||||
layer_names = initialization_manager.get_layer_names_in_sft_format()
|
||||
|
||||
5
train.py
5
train.py
@ -24,7 +24,7 @@ from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
|
||||
import picotron.process_group_manager as pgm
|
||||
from picotron.utils import set_all_seed, print, to_readable_format
|
||||
from picotron.checkpoint import CheckpointManager
|
||||
from picotron.checkpoint import init_model_with_dematerialized_weights, initialize_model_with_materialized_weights
|
||||
from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
|
||||
from picotron.data import MicroBatchDataLoader
|
||||
from picotron.process_group_manager import setup_process_group_manager
|
||||
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
|
||||
@ -180,8 +180,7 @@ if __name__ == "__main__":
|
||||
if pgm.process_group_manager.pp_world_size > 1:
|
||||
model = PipelineParallel(model, model_config)
|
||||
|
||||
#TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ?
|
||||
model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
|
||||
model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
|
||||
|
||||
if pgm.process_group_manager.cp_world_size > 1:
|
||||
model = apply_context_parallel(model)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user