support CPU training through gloo backend

commit b8065de7aa
parent 6f6bc1945a
data_parallel.py

@@ -20,5 +20,5 @@ class DataParallel(nn.Module):
     def all_reduce_gradients(self):
         for param in self.model.parameters():
             if param.grad is not None:
-                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
+                param.grad /= self.dp_world_size
+                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
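
Note on the change above: torch.distributed's ReduceOp.AVG is only implemented by the NCCL backend, so a Gloo/CPU run has to express the average as a local scale plus an all-reduce SUM (equivalent either way round, since division is linear). A minimal self-contained sketch of the same idea; the helper name is made up for illustration:

import torch
import torch.distributed as dist

def all_reduce_mean(grad: torch.Tensor, group=None) -> torch.Tensor:
    """Backend-agnostic gradient averaging: Gloo has no ReduceOp.AVG,
    so scale locally and all-reduce with SUM (scale-then-sum and
    sum-then-scale give the same result)."""
    grad /= dist.get_world_size(group=group)
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
    return grad
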
pipeline_parallel.py

@@ -5,19 +5,19 @@ import process_group_manager as pgm

 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

-def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
+def communicate(operation, device, dtype, tensor=None, shapes=None):
     global STEP
     global VERBOSE
     if operation == 'recv_forward':
         if pgm.process_group_manager.pp_is_first_stage: return None
-        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
+        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
         src = pgm.process_group_manager.pp_prev_rank
     elif operation == 'send_forward':
         if pgm.process_group_manager.pp_is_last_stage: return
         dest = pgm.process_group_manager.pp_next_rank
     elif operation == 'recv_backward':
         if pgm.process_group_manager.pp_is_last_stage: return None
-        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
+        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
         src = pgm.process_group_manager.pp_next_rank
     elif operation == 'send_backward':
         if pgm.process_group_manager.pp_is_first_stage: return
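
This signature change is the core of the CPU port: the receive buffers were hard-coded to device='cuda', but Gloo exchanges CPU tensors while NCCL requires CUDA tensors, so the caller now decides where the buffer lives. A small sketch of the device selection the call sites rely on (the buffer shape and the CUDA-availability fallback are illustrative, not from this commit):

import os
import torch

# Pick the device that matches the torch.distributed backend:
# NCCL moves CUDA tensors, Gloo moves CPU tensors.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
use_cpu = not torch.cuda.is_available()
device = torch.device("cpu") if use_cpu else torch.device("cuda", local_rank)

# Any receive buffer must then live on that device:
buf = torch.empty((2, 10, 768), requires_grad=True, device=device, dtype=torch.float32)
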
@@ -31,7 +31,7 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
     if VERBOSE: STEP += 1
     return tensor if not is_send else None

-def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device):
+def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype):
     global STEP
     global VERBOSE
     is_fwd = (operation == 'send_fwd_recv_bwd')
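
bidirectional_communicate itself only gets its keyword order aligned with communicate (device before dtype); its body is not part of this diff. For context, a combined send/receive like this is commonly built on torch.distributed.batch_isend_irecv so paired pipeline stages cannot deadlock waiting on each other — the sketch below is a generic illustration under that assumption, not the repository's implementation:

import torch
import torch.distributed as dist

def send_recv(send_tensor, recv_shapes, peer_send, peer_recv, device, dtype):
    """Post one isend and one irecv together so both directions progress
    regardless of which neighbour reaches the exchange first."""
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer_send),
        dist.P2POp(dist.irecv, recv_tensor, peer_recv),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_tensor
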
@@ -44,11 +44,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
     input_tensors, output_tensors = [], []

     for _ in range(data_loader.num_local_micro_batches): # All forward passes
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
+        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         batch = next(iter(data_loader))
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
-        communicate(operation='send_forward', tensor=output_tensor)
+        communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)

         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
         if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
@@ -59,10 +59,10 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         output_tensors.append(output_tensor)

     for _ in range(data_loader.num_local_micro_batches): # All backward passes
-        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
+        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        communicate(operation='send_backward', tensor=input_tensor_grad)
+        communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
@@ -84,33 +84,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         return output_tensor

     for _ in range(num_warmup_microbatches): # Warmup forward passes
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
+        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         output_tensor = _forward_step(input_tensor)
-        communicate(operation='send_forward', tensor=output_tensor)
+        communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)

     if num_microbatches_remaining > 0:
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
+        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)

     for i in range(num_microbatches_remaining): # 1F1B steady state
         output_tensor = _forward_step(input_tensor)
-        output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
+        output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         if i == num_microbatches_remaining - 1: # last iteration
             input_tensor = None
-            communicate(operation='send_backward', tensor=input_tensor_grad)
+            communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
         else:
-            input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
+            input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)

     for _ in range(num_warmup_microbatches): # Cooldown backward passes
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
-        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
+        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        communicate(operation='send_backward', tensor=input_tensor_grad)
+        communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
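
For reference, the warmup / steady-state / cooldown structure above is the usual 1F1B schedule; the two counters are computed before these loops and are not part of this hunk. A standard way to derive them (an assumption about this codebase, shown with generic names):

def one_f_one_b_counts(num_microbatches: int, pp_rank: int, pp_size: int):
    """Standard 1F1B bookkeeping (illustrative, not taken from this diff):
    each stage runs as many warmup forwards as there are stages after it,
    then alternates one forward with one backward for the remainder."""
    num_warmup = min(pp_size - pp_rank - 1, num_microbatches)
    return num_warmup, num_microbatches - num_warmup
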
train.py  (33 changed lines)

@@ -8,7 +8,7 @@ from transformers import AutoConfig, AutoModelForCausalLM
 import argparse

 import process_group_manager as pgm
-from utils import set_all_seed, display_parallelism_grid
+from utils import set_all_seed, display_parallelism_grid, print
 from process_group_manager import setup_process_group_manager
 from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from data_parallel import DataParallel
@@ -46,18 +46,33 @@ if __name__ == "__main__":
     parser.add_argument("--pp_size", type=int, default=1)
     parser.add_argument("--dp_size", type=int, default=1)
     parser.add_argument("--use_wandb", action="store_true", default=False)
+    parser.add_argument("--use_cpu", action="store_true", default=False)
+    parser.add_argument("--master_addr", type=str, default="localhost")
+    parser.add_argument("--master_port", type=int, default=29500)

     args = parser.parse_args()

+    os.environ["OMP_NUM_THREADS"] = "1"
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
-    host, port = os.environ["MASTER_ADDR"], int(os.environ["MASTER_PORT"])
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    host = os.environ["MASTER_ADDR"]
+    port = int(os.environ["MASTER_PORT"])

     SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42

-    dist.init_process_group(rank=local_rank, world_size=world_size, backend="nccl", init_method=f"tcp://{host}:{port}")
-    torch.cuda.set_device(local_rank)
-    device = torch.device("cuda", local_rank)
+    backend = "gloo" if args.use_cpu else "nccl"
+
+    if backend == "nccl":
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)
+    else:
+        device = torch.device("cpu")
+
+    dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
+
     setup_process_group_manager(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)

     if pgm.process_group_manager.global_rank == local_rank:
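
With the backend made configurable, a CPU-only run simply selects gloo and a cpu device; the process group manager and the pipeline/data-parallel code paths are untouched. The real script would still be launched the usual way (e.g. torchrun with the new --use_cpu flag, since it reads LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment). The standalone sketch below exercises the gloo gradient-averaging path end to end and runs without a GPU; the port, tensor shape and two-process setup are made up for the example:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"        # arbitrary free port for the example
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    grad = torch.full((4,), float(rank))       # stand-in for a CPU gradient
    grad /= world_size                         # scale locally, as in data_parallel.py
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: averaged grad = {grad.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)      # two CPU processes, no CUDA needed
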
@@ -73,9 +88,9 @@ if __name__ == "__main__":
         project="picotron",
         name=f"test_convergence_{pgm.process_group_manager}",
         config={
-            "data_parallel_size": pgm.process_group_manager.dp_size,
             "tensor_parallel_size": pgm.process_group_manager.tp_size,
             "pipeline_parallel_size": pgm.process_group_manager.pp_size,
+            "data_parallel_size": pgm.process_group_manager.dp_size,
             "model": model_name,
             "dataset": dataset_name,
             "max_tokens": MAX_TOKENS,
utils.py  (19 changed lines)

@@ -1,6 +1,19 @@
-import torch, random, numpy as np
+import torch
+import random
+import numpy as np
+import builtins
+import fcntl
 import process_group_manager as pgm

+def print(*args, **kwargs):
+    """ solves multi-process interleaved print problem """
+    with open(__file__, "r") as fh:
+        fcntl.flock(fh, fcntl.LOCK_EX)
+        try:
+            builtins.print(*args, **kwargs)
+        finally:
+            fcntl.flock(fh, fcntl.LOCK_UN)
+
 def set_all_seed(seed):
     for module in [random, np.random]: module.seed(seed)
     torch.manual_seed(seed)
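
The custom print above serializes output across processes: it takes an exclusive advisory lock (fcntl.flock) on utils.py itself before delegating to builtins.print, and the try/finally releases the lock even if printing fails. That is also why train.py now imports print from utils, shadowing the builtin. Usage is simply (rank value illustrative):

from utils import print          # shadows builtins.print with the locked version

rank = 0                         # illustrative; normally dist.get_rank()
print(f"[rank {rank}] starting training step")   # lines from different ranks no longer interleave
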
@@ -50,8 +63,8 @@ def display_parallelism_grid():

     output.append(f"=== Local Parallelism Configuration ===")
     output.append(pgm.process_group_manager.__str__())
-    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
-    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
     output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.tp_group_ids]}")
+    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
+    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")

     print("\n".join(output))