support CPU training through gloo backend

This commit is contained in:
ferdinand.mom 2024-09-26 10:27:20 +00:00
parent 6f6bc1945a
commit b8065de7aa
5 changed files with 58 additions and 30 deletions

View File

@ -20,5 +20,5 @@ class DataParallel(nn.Module):
def all_reduce_gradients(self):
for param in self.model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
param.grad /= self.dp_world_size
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)

View File

@ -5,19 +5,19 @@ import process_group_manager as pgm
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
def communicate(operation, device, dtype, tensor=None, shapes=None):
global STEP
global VERBOSE
if operation == 'recv_forward':
if pgm.process_group_manager.pp_is_first_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
src = pgm.process_group_manager.pp_prev_rank
elif operation == 'send_forward':
if pgm.process_group_manager.pp_is_last_stage: return
dest = pgm.process_group_manager.pp_next_rank
elif operation == 'recv_backward':
if pgm.process_group_manager.pp_is_last_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
src = pgm.process_group_manager.pp_next_rank
elif operation == 'send_backward':
if pgm.process_group_manager.pp_is_first_stage: return
@ -31,7 +31,7 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
if VERBOSE: STEP += 1
return tensor if not is_send else None
def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device):
def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype):
global STEP
global VERBOSE
is_fwd = (operation == 'send_fwd_recv_bwd')

View File

@ -44,11 +44,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
input_tensors, output_tensors = [], []
for _ in range(data_loader.num_local_micro_batches): # All forward passes
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
batch = next(iter(data_loader))
batch["hidden_states"] = input_tensor
output_tensor = model.forward(batch, device)
communicate(operation='send_forward', tensor=output_tensor)
communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
@ -59,10 +59,10 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
output_tensors.append(output_tensor)
for _ in range(data_loader.num_local_micro_batches): # All backward passes
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
communicate(operation='send_backward', tensor=input_tensor_grad)
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
return logging_loss
@ -84,33 +84,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
return output_tensor
for _ in range(num_warmup_microbatches): # Warmup forward passes
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor = _forward_step(input_tensor)
communicate(operation='send_forward', tensor=output_tensor)
communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if num_microbatches_remaining > 0:
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
for i in range(num_microbatches_remaining): # 1F1B steady state
output_tensor = _forward_step(input_tensor)
output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
if i == num_microbatches_remaining - 1: # last iteration
input_tensor = None
communicate(operation='send_backward', tensor=input_tensor_grad)
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
else:
input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
for _ in range(num_warmup_microbatches): # Cooldown backward passes
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
communicate(operation='send_backward', tensor=input_tensor_grad)
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
return logging_loss

View File

@ -8,7 +8,7 @@ from transformers import AutoConfig, AutoModelForCausalLM
import argparse
import process_group_manager as pgm
from utils import set_all_seed, display_parallelism_grid
from utils import set_all_seed, display_parallelism_grid, print
from process_group_manager import setup_process_group_manager
from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from data_parallel import DataParallel
@ -46,18 +46,33 @@ if __name__ == "__main__":
parser.add_argument("--pp_size", type=int, default=1)
parser.add_argument("--dp_size", type=int, default=1)
parser.add_argument("--use_wandb", action="store_true", default=False)
parser.add_argument("--use_cpu", action="store_true", default=False)
parser.add_argument("--master_addr", type=str, default="localhost")
parser.add_argument("--master_port", type=int, default=29500)
args = parser.parse_args()
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
host, port = os.environ["MASTER_ADDR"], int(os.environ["MASTER_PORT"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
dist.init_process_group(rank=local_rank, world_size=world_size, backend="nccl", init_method=f"tcp://{host}:{port}")
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
backend = "gloo" if args.use_cpu else "nccl"
if backend == "nccl":
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
else:
device = torch.device("cpu")
dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
setup_process_group_manager(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
if pgm.process_group_manager.global_rank == local_rank:
@ -73,9 +88,9 @@ if __name__ == "__main__":
project="picotron",
name=f"test_convergence_{pgm.process_group_manager}",
config={
"data_parallel_size": pgm.process_group_manager.dp_size,
"tensor_parallel_size": pgm.process_group_manager.tp_size,
"pipeline_parallel_size": pgm.process_group_manager.pp_size,
"data_parallel_size": pgm.process_group_manager.dp_size,
"model": model_name,
"dataset": dataset_name,
"max_tokens": MAX_TOKENS,

View File

@ -1,6 +1,19 @@
import torch, random, numpy as np
import torch
import random
import numpy as np
import builtins
import fcntl
import process_group_manager as pgm
def print(*args, **kwargs):
""" solves multi-process interleaved print problem """
with open(__file__, "r") as fh:
fcntl.flock(fh, fcntl.LOCK_EX)
try:
builtins.print(*args, **kwargs)
finally:
fcntl.flock(fh, fcntl.LOCK_UN)
def set_all_seed(seed):
for module in [random, np.random]: module.seed(seed)
torch.manual_seed(seed)
@ -50,8 +63,8 @@ def display_parallelism_grid():
output.append(f"=== Local Parallelism Configuration ===")
output.append(pgm.process_group_manager.__str__())
output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.tp_group_ids]}")
output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
print("\n".join(output))