From b8065de7aa96484e4b661b92e81a21d385a2b43f Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Thu, 26 Sep 2024 10:27:20 +0000
Subject: [PATCH] support CPU training through gloo backend

---
 data_parallel.py         |  4 ++--
 distributed_primtives.py |  8 ++++----
 pipeline_parallel.py     | 24 ++++++++++++------------
 train.py                 | 33 ++++++++++++++++++++++++---------
 utils.py                 | 19 ++++++++++++++++---
 5 files changed, 58 insertions(+), 30 deletions(-)

diff --git a/data_parallel.py b/data_parallel.py
index 5c7cc07..24e2cc0 100644
--- a/data_parallel.py
+++ b/data_parallel.py
@@ -20,5 +20,5 @@ class DataParallel(nn.Module):
     def all_reduce_gradients(self):
         for param in self.model.parameters():
             if param.grad is not None:
-                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
-                
\ No newline at end of file
+                param.grad /= self.dp_world_size
+                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
\ No newline at end of file
diff --git a/distributed_primtives.py b/distributed_primtives.py
index dc52795..824f735 100644
--- a/distributed_primtives.py
+++ b/distributed_primtives.py
@@ -5,19 +5,19 @@ import process_group_manager as pgm
 
 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
 
-def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
+def communicate(operation, device, dtype, tensor=None, shapes=None):
     global STEP
     global VERBOSE
     if operation == 'recv_forward':
         if pgm.process_group_manager.pp_is_first_stage: return None
-        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
+        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
         src = pgm.process_group_manager.pp_prev_rank
     elif operation == 'send_forward':
         if pgm.process_group_manager.pp_is_last_stage: return
         dest = pgm.process_group_manager.pp_next_rank
     elif operation == 'recv_backward':
         if pgm.process_group_manager.pp_is_last_stage: return None
-        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
+        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
         src = pgm.process_group_manager.pp_next_rank
     elif operation == 'send_backward':
         if pgm.process_group_manager.pp_is_first_stage: return
@@ -31,7 +31,7 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
     if VERBOSE: STEP += 1
     return tensor if not is_send else None
 
-def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device):
+def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype):
     global STEP
     global VERBOSE
     is_fwd = (operation == 'send_fwd_recv_bwd')
diff --git a/pipeline_parallel.py b/pipeline_parallel.py
index 3e4d3f0..dbfb1f7 100644
--- a/pipeline_parallel.py
+++ b/pipeline_parallel.py
@@ -44,11 +44,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
     input_tensors, output_tensors = [], []
 
     for _ in range(data_loader.num_local_micro_batches): # All forward passes
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
+        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         batch = next(iter(data_loader))
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
-        communicate(operation='send_forward', tensor=output_tensor)
+        communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
 
         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
@@ -59,10 +59,10 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         output_tensors.append(output_tensor)
 
     for _ in range(data_loader.num_local_micro_batches): # All backward passes
-        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
+        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        communicate(operation='send_backward', tensor=input_tensor_grad)
+        communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
 
     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
@@ -84,33 +84,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         return output_tensor
 
     for _ in range(num_warmup_microbatches): # Warmup forward passes
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
+        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         output_tensor = _forward_step(input_tensor)
-        communicate(operation='send_forward', tensor=output_tensor)
+        communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
 
     if num_microbatches_remaining > 0:
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
+        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
 
     for i in range(num_microbatches_remaining): # 1F1B steady state
         output_tensor = _forward_step(input_tensor)
-        output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
+        output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         if i == num_microbatches_remaining - 1: # last iteration
             input_tensor = None
-            communicate(operation='send_backward', tensor=input_tensor_grad)
+            communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
         else:
-            input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
+            input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
 
     for _ in range(num_warmup_microbatches): # Cooldown backward passes
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
-        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
+        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        communicate(operation='send_backward', tensor=input_tensor_grad)
+        communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
 
     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
\ No newline at end of file
diff --git a/train.py b/train.py
index 821d0a4..780942b 100644
--- a/train.py
+++ b/train.py
@@ -8,7 +8,7 @@ from transformers import AutoConfig, AutoModelForCausalLM
 import argparse
 
 import process_group_manager as pgm
-from utils import set_all_seed, display_parallelism_grid
+from utils import set_all_seed, display_parallelism_grid, print
 from process_group_manager import setup_process_group_manager
 from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from data_parallel import DataParallel
@@ -46,18 +46,33 @@ if __name__ == "__main__":
     parser.add_argument("--pp_size", type=int, default=1)
     parser.add_argument("--dp_size", type=int, default=1)
     parser.add_argument("--use_wandb", action="store_true", default=False)
+    parser.add_argument("--use_cpu", action="store_true", default=False)
+    parser.add_argument("--master_addr", type=str, default="localhost")
+    parser.add_argument("--master_port", type=int, default=29500)
     args = parser.parse_args()
 
+    os.environ["OMP_NUM_THREADS"] = "1"
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
-    host, port = os.environ["MASTER_ADDR"], int(os.environ["MASTER_PORT"])
-    
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    host = os.environ["MASTER_ADDR"]
+    port = int(os.environ["MASTER_PORT"])
+
     SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
-    
-    dist.init_process_group(rank=local_rank, world_size=world_size, backend="nccl", init_method=f"tcp://{host}:{port}")
-    torch.cuda.set_device(local_rank)
-    device = torch.device("cuda", local_rank)
+
+    backend = "gloo" if args.use_cpu else "nccl"
+
+    if backend == "nccl":
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)
+    else:
+        device = torch.device("cpu")
+
+
+    dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
+
     setup_process_group_manager(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
 
     if pgm.process_group_manager.global_rank == local_rank:
@@ -73,9 +88,9 @@ if __name__ == "__main__":
             project="picotron",
             name=f"test_convergence_{pgm.process_group_manager}",
             config={
-                "data_parallel_size": pgm.process_group_manager.dp_size,
                 "tensor_parallel_size": pgm.process_group_manager.tp_size,
                 "pipeline_parallel_size": pgm.process_group_manager.pp_size,
+                "data_parallel_size": pgm.process_group_manager.dp_size,
                 "model": model_name,
                 "dataset": dataset_name,
                 "max_tokens": MAX_TOKENS,
diff --git a/utils.py b/utils.py
index 20a2481..807ce86 100644
--- a/utils.py
+++ b/utils.py
@@ -1,6 +1,19 @@
-import torch, random, numpy as np
+import torch
+import random
+import numpy as np
+import builtins
+import fcntl
 import process_group_manager as pgm
 
+def print(*args, **kwargs):
+    """ solves multi-process interleaved print problem """
+    with open(__file__, "r") as fh:
+        fcntl.flock(fh, fcntl.LOCK_EX)
+        try:
+            builtins.print(*args, **kwargs)
+        finally:
+            fcntl.flock(fh, fcntl.LOCK_UN)
+
 def set_all_seed(seed):
     for module in [random, np.random]: module.seed(seed)
     torch.manual_seed(seed)
@@ -50,8 +63,8 @@ def display_parallelism_grid():
 
     output.append(f"=== Local Parallelism Configuration ===")
     output.append(pgm.process_group_manager.__str__())
-    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
-    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
     output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.tp_group_ids]}")
+    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
+    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
 
     print("\n".join(output))
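
Note (not part of the patch): the data_parallel.py hunk swaps dist.ReduceOp.AVG for a
local division by the DP world size followed by a SUM all-reduce because, as far as I
know, ReduceOp.AVG is only available with the NCCL backend and is therefore unusable once
the process group is initialized with gloo for CPU training. A minimal stand-alone sketch
of the same idea, under that assumption (the helper name all_reduce_mean is illustrative
and does not come from the repository):

    import torch
    import torch.distributed as dist

    # Sketch: average a tensor across ranks using only SUM, which gloo supports.
    # Assumes dist.init_process_group(...) has already been called.
    def all_reduce_mean(grad: torch.Tensor, group=None) -> torch.Tensor:
        grad /= dist.get_world_size(group=group)                  # pre-scale locally ...
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)  # ... so the SUM of scaled tensors equals the average
        return grad

With the flags added in train.py, a CPU run would then presumably be launched through
torchrun (or another launcher that sets LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and
MASTER_PORT) with --use_cpu, which selects the gloo backend and keeps all tensors on
torch.device("cpu").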