support CPU training through gloo backend
This commit is contained in:
parent
6f6bc1945a
commit
b8065de7aa
@ -20,5 +20,5 @@ class DataParallel(nn.Module):
|
||||
def all_reduce_gradients(self):
|
||||
for param in self.model.parameters():
|
||||
if param.grad is not None:
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
|
||||
|
||||
param.grad /= self.dp_world_size
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
|
||||
@ -5,19 +5,19 @@ import process_group_manager as pgm
|
||||
|
||||
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
|
||||
|
||||
def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
|
||||
def communicate(operation, device, dtype, tensor=None, shapes=None):
|
||||
global STEP
|
||||
global VERBOSE
|
||||
if operation == 'recv_forward':
|
||||
if pgm.process_group_manager.pp_is_first_stage: return None
|
||||
tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
|
||||
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
|
||||
src = pgm.process_group_manager.pp_prev_rank
|
||||
elif operation == 'send_forward':
|
||||
if pgm.process_group_manager.pp_is_last_stage: return
|
||||
dest = pgm.process_group_manager.pp_next_rank
|
||||
elif operation == 'recv_backward':
|
||||
if pgm.process_group_manager.pp_is_last_stage: return None
|
||||
tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
|
||||
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
|
||||
src = pgm.process_group_manager.pp_next_rank
|
||||
elif operation == 'send_backward':
|
||||
if pgm.process_group_manager.pp_is_first_stage: return
|
||||
@ -31,7 +31,7 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
|
||||
if VERBOSE: STEP += 1
|
||||
return tensor if not is_send else None
|
||||
|
||||
def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device):
|
||||
def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype):
|
||||
global STEP
|
||||
global VERBOSE
|
||||
is_fwd = (operation == 'send_fwd_recv_bwd')
|
||||
|
||||
@ -44,11 +44,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
|
||||
input_tensors, output_tensors = [], []
|
||||
|
||||
for _ in range(data_loader.num_local_micro_batches): # All forward passes
|
||||
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
|
||||
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
|
||||
batch = next(iter(data_loader))
|
||||
batch["hidden_states"] = input_tensor
|
||||
output_tensor = model.forward(batch, device)
|
||||
communicate(operation='send_forward', tensor=output_tensor)
|
||||
communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
|
||||
|
||||
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
|
||||
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
|
||||
@ -59,10 +59,10 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
for _ in range(data_loader.num_local_micro_batches): # All backward passes
|
||||
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
|
||||
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
|
||||
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
|
||||
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
|
||||
communicate(operation='send_backward', tensor=input_tensor_grad)
|
||||
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
|
||||
|
||||
logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
|
||||
return logging_loss
|
||||
@ -84,33 +84,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
|
||||
return output_tensor
|
||||
|
||||
for _ in range(num_warmup_microbatches): # Warmup forward passes
|
||||
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
|
||||
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
|
||||
output_tensor = _forward_step(input_tensor)
|
||||
communicate(operation='send_forward', tensor=output_tensor)
|
||||
communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
|
||||
input_tensors.append(input_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
if num_microbatches_remaining > 0:
|
||||
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
|
||||
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
|
||||
|
||||
for i in range(num_microbatches_remaining): # 1F1B steady state
|
||||
output_tensor = _forward_step(input_tensor)
|
||||
output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
|
||||
output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
|
||||
input_tensors.append(input_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
|
||||
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
|
||||
if i == num_microbatches_remaining - 1: # last iteration
|
||||
input_tensor = None
|
||||
communicate(operation='send_backward', tensor=input_tensor_grad)
|
||||
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
|
||||
else:
|
||||
input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
|
||||
input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
|
||||
|
||||
for _ in range(num_warmup_microbatches): # Cooldown backward passes
|
||||
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
|
||||
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
|
||||
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
|
||||
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
|
||||
communicate(operation='send_backward', tensor=input_tensor_grad)
|
||||
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
|
||||
|
||||
logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
|
||||
return logging_loss
|
||||
33
train.py
33
train.py
@ -8,7 +8,7 @@ from transformers import AutoConfig, AutoModelForCausalLM
|
||||
import argparse
|
||||
|
||||
import process_group_manager as pgm
|
||||
from utils import set_all_seed, display_parallelism_grid
|
||||
from utils import set_all_seed, display_parallelism_grid, print
|
||||
from process_group_manager import setup_process_group_manager
|
||||
from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
|
||||
from data_parallel import DataParallel
|
||||
@ -46,18 +46,33 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--pp_size", type=int, default=1)
|
||||
parser.add_argument("--dp_size", type=int, default=1)
|
||||
parser.add_argument("--use_wandb", action="store_true", default=False)
|
||||
parser.add_argument("--use_cpu", action="store_true", default=False)
|
||||
parser.add_argument("--master_addr", type=str, default="localhost")
|
||||
parser.add_argument("--master_port", type=int, default=29500)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
|
||||
host, port = os.environ["MASTER_ADDR"], int(os.environ["MASTER_PORT"])
|
||||
|
||||
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
host = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"])
|
||||
|
||||
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
|
||||
|
||||
dist.init_process_group(rank=local_rank, world_size=world_size, backend="nccl", init_method=f"tcp://{host}:{port}")
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
|
||||
backend = "gloo" if args.use_cpu else "nccl"
|
||||
|
||||
if backend == "nccl":
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
|
||||
dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
|
||||
|
||||
setup_process_group_manager(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
|
||||
|
||||
if pgm.process_group_manager.global_rank == local_rank:
|
||||
@ -73,9 +88,9 @@ if __name__ == "__main__":
|
||||
project="picotron",
|
||||
name=f"test_convergence_{pgm.process_group_manager}",
|
||||
config={
|
||||
"data_parallel_size": pgm.process_group_manager.dp_size,
|
||||
"tensor_parallel_size": pgm.process_group_manager.tp_size,
|
||||
"pipeline_parallel_size": pgm.process_group_manager.pp_size,
|
||||
"data_parallel_size": pgm.process_group_manager.dp_size,
|
||||
"model": model_name,
|
||||
"dataset": dataset_name,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
|
||||
19
utils.py
19
utils.py
@ -1,6 +1,19 @@
|
||||
import torch, random, numpy as np
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import builtins
|
||||
import fcntl
|
||||
import process_group_manager as pgm
|
||||
|
||||
def print(*args, **kwargs):
|
||||
""" solves multi-process interleaved print problem """
|
||||
with open(__file__, "r") as fh:
|
||||
fcntl.flock(fh, fcntl.LOCK_EX)
|
||||
try:
|
||||
builtins.print(*args, **kwargs)
|
||||
finally:
|
||||
fcntl.flock(fh, fcntl.LOCK_UN)
|
||||
|
||||
def set_all_seed(seed):
|
||||
for module in [random, np.random]: module.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
@ -50,8 +63,8 @@ def display_parallelism_grid():
|
||||
|
||||
output.append(f"=== Local Parallelism Configuration ===")
|
||||
output.append(pgm.process_group_manager.__str__())
|
||||
output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
|
||||
output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
|
||||
output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.tp_group_ids]}")
|
||||
output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
|
||||
output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
|
||||
|
||||
print("\n".join(output))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user