leave out CP integration at the very end

parent d0d6d8994f
commit 83ddda2ce8

train.py (16 lines changed)
@@ -8,7 +8,6 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
 
-import multiprocessing
 import os
 import torch.nn.functional as F
 import torch, torch.distributed as dist
@@ -26,10 +25,9 @@ from utils import set_all_seed, print
 from src.distributed.process_group_manager import setup_process_group_manager
 from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from src.parallel.data_parallel.data_parallel_bucket import DataParallel
-# from src.parallel.context_parallel import ContextParallel
-from model import LLaMA
+from src.parallel.context_parallel import ContextParallel
+from model import Llama
 import wandb
-import multiprocessing
 
 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, grad_acc = 1, split="train", num_samples=None, num_workers=0):
@@ -205,9 +203,7 @@ if __name__ == "__main__":
     config.num_attention_heads = 16
     config.num_key_value_heads = 4
 
-    model = LLaMA(
-        config=config
-    )
+    model = Llama(config=config)
 
     if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
         wandb.init(
@@ -231,13 +227,14 @@ if __name__ == "__main__":
         TensorParallel(model)
 
     # if pgm.process_group_manager.cp_size > 1:
+    #TODO: do at the very end, once the convergence issue is fixed
     #     model = ContextParallel(model, config)
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, config)
 
     if pgm.process_group_manager.dp_world_size > 1:
-        model = DataParallel(model, pgm.process_group_manager.dp_group)
+        model = DataParallel(model)
 
     model.to(device)
     model.train()
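Aside: the hunk above drops the explicit dp_group argument from the DataParallel call. A minimal sketch of what that new signature could imply, assuming the bucketed wrapper resolves its group from the global process group manager itself; the import style, class body, and all_reduce_gradients helper below are illustrative assumptions, not the actual data_parallel_bucket implementation from this repo.

import torch.nn as nn
import torch.distributed as dist
import src.distributed.process_group_manager as pgm  # assumed import style

class DataParallel(nn.Module):
    """Hypothetical sketch: the wrapper looks up its own DP group."""
    def __init__(self, model):
        super().__init__()
        self.module = model
        # Resolved internally instead of being passed by the caller.
        self.dp_group = pgm.process_group_manager.dp_group

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def all_reduce_gradients(self):
        # Average gradients across data-parallel ranks (bucketing omitted).
        for p in self.module.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=self.dp_group)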
@@ -267,6 +264,7 @@ if __name__ == "__main__":
 
         # average the loss across all DP/CP ranks
         if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
+            #TODO: use all_reduce function from distributed_primitives.py
             loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
             handle = dist.all_reduce(loss_tensor, group=pgm.process_group_manager.cp_dp_group, async_op=True, op=dist.ReduceOp.AVG)
 
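Aside: the TODO in the last hunk points at distributed_primitives.py. A minimal sketch of the kind of helper it could call, using only the torch.distributed calls already shown in the diff; the helper name and signature are hypothetical, not the project's actual API.

import torch
import torch.distributed as dist

def all_reduce_loss_across_dp_cp_ranks(loss: float, device, group=None):
    """Average a scalar loss over the given group with a non-blocking all-reduce.

    Returns the tensor and the async work handle; call handle.wait() (or rely on
    a later synchronization point) before reading the value.
    """
    loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
    handle = dist.all_reduce(loss_tensor, group=group, op=dist.ReduceOp.AVG, async_op=True)
    return loss_tensor, handle

# Possible call site in train.py (mirrors the diff; pgm as used there):
# loss_tensor, handle = all_reduce_loss_across_dp_cp_ranks(
#     loss, device, group=pgm.process_group_manager.cp_dp_group)
# handle.wait()
# loss = loss_tensor.item()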