commit 1e229cae88 (parent 3095ff4d4f)
ferdinand.mom  2024-10-14 09:26:31 +00:00
6 changed files with 107 additions and 142 deletions

.gitignore (vendored, 3 added lines)

@@ -0,0 +1,3 @@
__pycache__
*.pth
.vscode/

distributed/distributed_primtives.py

@@ -5,7 +5,7 @@ import distributed.process_group_manager as pgm
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
def communicate(operation, device, dtype, tensor=None, shapes=None):
def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
global STEP
global VERBOSE
if operation == 'recv_forward':
@@ -31,7 +31,7 @@ def communicate(operation, device, dtype, tensor=None, shapes=None):
if VERBOSE: STEP += 1
return tensor if not is_send else None
def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype):
def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
global STEP
global VERBOSE
is_fwd = (operation == 'send_fwd_recv_bwd')
@@ -43,4 +43,17 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype
[req.wait() for req in reqs]
torch.cuda.synchronize()
if VERBOSE: STEP += 1
return recv_tensor
return recv_tensor
def all_reduce_loss_across_dp_ranks(loss, device):
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # Reduce the loss across all workers so that every rank has the updated loss value.
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
    reduced_loss /= pgm.process_group_manager.dp_world_size
    return reduced_loss.item()

def all_reduce_gradients_across_dp_cp_ranks(model):
    for param in model.parameters():
        if param.grad is not None:
            # Average the gradients across all DP & CP ranks
            param.grad /= pgm.process_group_manager.cp_dp_world_size
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
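
The hunk above only shows fragments of the renamed send/receive helpers. For reference, here is a minimal self-contained sketch of the idea behind pipeline_communicate (a blocking point-to-point transfer between adjacent pipeline stages). It is an illustration, not this repo's implementation; in particular, prev_rank and next_rank are assumed to be passed in explicitly rather than read from pgm.process_group_manager.

    import torch
    import torch.distributed as dist

    def simple_pipeline_communicate(operation, device, dtype, tensor=None, shapes=None,
                                    prev_rank=None, next_rank=None):
        # Receive operations allocate a fresh buffer; send operations reuse `tensor`.
        if operation == 'recv_forward':
            tensor = torch.empty(shapes, dtype=dtype, device=device, requires_grad=True)
            op = dist.P2POp(dist.irecv, tensor, prev_rank)
        elif operation == 'send_forward':
            op = dist.P2POp(dist.isend, tensor, next_rank)
        elif operation == 'recv_backward':
            tensor = torch.empty(shapes, dtype=dtype, device=device, requires_grad=True)
            op = dist.P2POp(dist.irecv, tensor, next_rank)
        elif operation == 'send_backward':
            op = dist.P2POp(dist.isend, tensor, prev_rank)
        else:
            raise ValueError(f"unknown operation: {operation}")
        # Launch the batched async op and block until it completes.
        for req in dist.batch_isend_irecv([op]):
            req.wait()
        torch.cuda.synchronize()
        return tensor if operation.startswith('recv') else None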


@@ -8,7 +8,7 @@ from utils import set_all_seed
import distributed.process_group_manager as pgm
from distributed.process_group_manager import setup_process_group_manager
from parallel.pipeline_parallel import PipelineParallel
from distributed.distributed_primtives import communicate
from distributed.distributed_primtives import pipeline_communicate
from model import Llama
def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
@@ -24,14 +24,14 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
if pgm.process_group_manager.pp_is_last_stage:
logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device)
recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device)
recv_buffer = pipeline_communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device)
batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer
output_tensor = model.forward(batch, device)
# Send output to the next stage.
communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)
pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)
# Copy logits.
if pgm.process_group_manager.pp_is_last_stage:
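
The inference hunk is truncated right after the "# Copy logits." comment. One common way to finish this step (an assumption for illustration, not necessarily what this file does) is to broadcast the logits computed on the last pipeline stage to the other ranks; src_rank and group below are hypothetical placeholders for the last PP stage's global rank and the PP process group.

    import torch
    import torch.distributed as dist

    def share_logits_from_last_stage(logits, shape, src_rank, group, device):
        # Ranks that did not compute the logits allocate an empty receive buffer.
        if logits is None:
            logits = torch.empty(shape, dtype=torch.float32, device=device)
        # Broadcast from the last pipeline stage so every rank can return the result.
        dist.broadcast(logits, src=src_rank, group=group)
        return logits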

parallel/pipeline_parallel.py

@@ -1,15 +1,8 @@
import distributed.process_group_manager as pgm
from distributed.distributed_primtives import communicate, bidirectional_communicate
from distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate, all_reduce_loss_across_dp_ranks
import torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist
def reduce_loss_across_dp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
# Reduce the loss across all workers so that every rank has the updated loss value.
dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
reduced_loss /= pgm.process_group_manager.dp_world_size
return reduced_loss.item()
class PipelineParallel(nn.Module):
def __init__(self, model, config):
super().__init__()
@@ -45,11 +38,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
input_tensors, output_tensors = [], []
for _ in range(data_loader.num_local_micro_batches): # All forward passes
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
batch = next(iter(data_loader))
batch["hidden_states"] = input_tensor
output_tensor = model.forward(batch, device)
communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
@@ -60,12 +53,12 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
output_tensors.append(output_tensor)
for _ in range(data_loader.num_local_micro_batches): # All backward passes
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
return logging_loss
def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
@@ -85,33 +78,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
return output_tensor
for _ in range(num_warmup_microbatches): # Warmup forward passes
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor = _forward_step(input_tensor)
communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if num_microbatches_remaining > 0:
input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
for i in range(num_microbatches_remaining): # 1F1B steady state
output_tensor = _forward_step(input_tensor)
output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
if i == num_microbatches_remaining - 1: # last iteration
input_tensor = None
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
else:
input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
for _ in range(num_warmup_microbatches): # Cooldown backward passes
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
return logging_loss
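
The 1F1B hunk uses num_warmup_microbatches and num_microbatches_remaining without showing how they are derived. The usual Megatron-style 1F1B schedule computes them as below; this formula is assumed here and is not copied from this diff.

    def one_f_one_b_schedule(pp_rank, pp_world_size, num_microbatches):
        # Earlier stages run more warmup forwards so the last stage can start
        # alternating forward/backward (1F1B) immediately.
        num_warmup = min(pp_world_size - pp_rank - 1, num_microbatches)
        num_steady = num_microbatches - num_warmup   # steady-state 1F1B iterations
        num_cooldown = num_warmup                    # leftover backwards at the end
        return num_warmup, num_steady, num_cooldown

    # Example with 4 stages and 8 microbatches:
    # stage 0 -> (3, 5, 3) ... stage 3 (last) -> (0, 8, 0)
    for rank in range(4):
        print(rank, one_f_one_b_schedule(rank, 4, 8))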


@@ -10,13 +10,13 @@ from datasets import load_dataset
import argparse
import distributed.process_group_manager as pgm
from utils import set_all_seed, display_parallelism_grid, print
from distributed.distributed_primtives import all_reduce_gradients_across_dp_cp_ranks
from utils import set_all_seed, print, display_4D_parallelism_grid
from distributed.process_group_manager import setup_process_group_manager
from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from parallel.data_parallel import DataParallel
from parallel.context_parallel import ContextParallel
from model import Llama
from dataset import MicroBatchDataLoader
import wandb
class MicroBatchDataLoader(DataLoader):
@@ -68,13 +68,6 @@ def train_step(model, data_loader, device):
avg_loss = total_loss / data_loader.num_local_micro_batches
return avg_loss
def all_reduce_grads_across_dp_cp_ranks():
for param in model.parameters():
if param.grad is not None:
# Average the gradients across all DP & CP ranks
param.grad /= pgm.process_group_manager.cp_dp_world_size
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tp_size", type=int, default=1)
@@ -106,14 +99,13 @@ if __name__ == "__main__":
else:
device = torch.device("cpu")
dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size)
if pgm.process_group_manager.global_rank == 0:
display_parallelism_grid()
# if pgm.process_group_manager.global_rank == 0:
# display_4D_parallelism_grid()
set_all_seed(SEED)
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
dataset_name = "roneneldan/TinyStories"
@@ -145,6 +137,7 @@ if __name__ == "__main__":
config=config,
device=device,
).to(device)
model.load_state_dict(torch.load("smollm.pth"))
if pgm.process_group_manager.cp_size > 1:
@@ -156,6 +149,9 @@ if __name__ == "__main__":
if pgm.process_group_manager.dp_world_size > 1:
model = DataParallel(model, config).to(device)
# if pgm.process_group_manager.tp_world_size > 1:
# model = TensorParallel(model, config).to(device)
model.train()
data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
@@ -181,11 +177,12 @@ if __name__ == "__main__":
if pgm.process_group_manager.pp_world_size > 1:
loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
# loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device)
else:
loss = train_step(model, data_loader, device)
if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
all_reduce_grads_across_dp_cp_ranks()
all_reduce_gradients_across_dp_cp_ranks(model)
optimizer.step()
trained_tokens += tokens_per_step
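
A small note on all_reduce_gradients_across_dp_cp_ranks, which the training loop now calls right before optimizer.step(): it divides each gradient by the DP x CP world size before the SUM all-reduce, which gives the same result as summing first and dividing after, since sum_r(g_r / N) = (sum_r g_r) / N. A tiny single-process check of that identity:

    import torch

    # Pretend gradients coming from two data/context-parallel ranks.
    grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 6.0])]
    n = len(grads)
    scale_then_sum = sum(g / n for g in grads)   # what the helper does
    sum_then_scale = sum(grads) / n              # the "obvious" averaging order
    assert torch.allclose(scale_then_sum, sum_then_scale)
    print(scale_then_sum)   # tensor([2., 4.])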

utils.py (159 changed lines)

@@ -19,103 +19,62 @@ def set_all_seed(seed):
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
def display_parallelism_grid():
def _create_gpu_box(gpu_num, tp, cp, pp, dp):
return [
f" GPU {gpu_num:<2} ",
f" +----------+",
f" | tp{tp} cp{cp} |",
f" | pp{pp} dp{dp} |",
f" +----------+"
]
def _create_row(start_gpu, tp_size, cp, pp, dp):
boxes = [_create_gpu_box(start_gpu + i, i, cp, pp, dp) for i in range(tp_size)]
return [" ".join(row) for row in zip(*boxes)]
def _add_pp_label(output):
output.append(" | ")
output.append(" PP| ")
output.append(" | ")
def _add_cp_label(output):
output.append(" | CP")
def _add_vertical_separator(output):
output.append(" | ")
output.append(" | |")
def _add_vertical_arrow(output):
output.append(" | v")
def _add_horizontal_separator(output):
output.append("-" * 86)
def _create_tp_arrows_and_labels(tp_group_width):
tp_arrow = "-" * (tp_group_width - 4) + ">"
tp_label = f"{'TP':^{tp_group_width}}"
tp_arrows = f" {tp_arrow:<{tp_group_width}} {tp_arrow}"
tp_labels = f" {tp_label:<{tp_group_width}} {tp_label}"
return tp_arrows, tp_labels
def _create_dp_arrow_and_label(total_tp_width):
dp_arrow = "-" * (total_tp_width - 6) + ">"
dp_label = f"{'DP':^{total_tp_width}}"
return f" {dp_arrow}", f" {dp_label}"
output = []
tp_size = pgm.process_group_manager.tp_size
cp_size = pgm.process_group_manager.cp_size
pp_size = pgm.process_group_manager.pp_size
dp_size = pgm.process_group_manager.dp_size
output.append("=== Global Parallelism Configuration ===")
output.append(f"TP Size: {tp_size}, CP_size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}")
output.append("")
for dp in range(0, dp_size, 2):
output.append(" | ")
for pp in range(pp_size):
if pp == pp_size // 2:
_add_pp_label(output)
_add_vertical_separator(output)
for cp in range(cp_size):
left_start_gpu = dp * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size
left_row = _create_row(left_start_gpu, tp_size, cp, pp, dp)
if dp + 1 < dp_size:
right_start_gpu = (dp+1) * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size
right_row = _create_row(right_start_gpu, tp_size, cp, pp, dp+1)
for l, r in zip(left_row, right_row):
output.append(f" | | {l:<33} {r}")
else:
for l in left_row:
output.append(f" | | {l}")
if cp < cp_size - 1:
_add_cp_label(output)
output.append(" | |")
_add_vertical_arrow(output)
if pp < pp_size - 1:
output.append(" | ")
output.append(" | ")
output.append(" v ")
if dp + 2 < dp_size:
_add_horizontal_separator(output)
tp_group_width = tp_size * 13 - 1
total_tp_width = tp_group_width * 2 + 18
tp_arrows, tp_labels = _create_tp_arrows_and_labels(tp_group_width)
dp_arrow, dp_label = _create_dp_arrow_and_label(total_tp_width)
output.extend(["", tp_arrows, tp_labels, "", dp_arrow, dp_label])
print("\n".join(output))
## def display_4D_parallelism_grid():
# #TODO(fmom): fix me
# #TODO(fmom): add color to distinguish between different parallelism groups
# def create_gpu_box(gpu_num, tp, cp, pp):
# return [
# f"+------+",
# f"|GPU:{gpu_num:<2d}|",
# f"| TP:{tp:d} |",
# f"| CP:{cp:d} |",
# f"| PP:{pp:d} |",
# f"+------+"
# ]
#
# def create_node(start_gpu, tp_size, cp_size, pp_size, node_index):
# boxes = []
# for i in range(8): # 8 GPUs per node
# gpu_num = start_gpu + i
# tp = gpu_num % tp_size
# cp = (gpu_num // tp_size) % cp_size
# pp = (gpu_num // (tp_size * cp_size)) % pp_size
# boxes.append(create_gpu_box(gpu_num, tp, cp, pp))
# return [' '.join(row) for row in zip(*boxes)]
#
# def create_dp_box(replica_output):
# width = len(replica_output[0]) + 4
# top_bottom = f"+{'-' * (width - 2)}+"
# return [top_bottom] + [f"| {line} |" for line in replica_output] + [top_bottom]
#
# tp_size = pgm.process_group_manager.tp_size
# cp_size = pgm.process_group_manager.cp_size
# pp_size = pgm.process_group_manager.pp_size
# dp_size = pgm.process_group_manager.dp_size
# total_gpus_per_replica = tp_size * cp_size * pp_size
# num_nodes_per_replica = (total_gpus_per_replica + 7) // 8 # Round up to nearest whole node
#
# output = []
# output.append("=== Simplified Parallelism Configuration ===")
# output.append(f"TP Size: {tp_size}, CP Size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}")
# output.append(f"Total GPUs for one replica: {total_gpus_per_replica}")
# output.append(f"Number of nodes per replica: {num_nodes_per_replica} (8 GPUs per node)")
# output.append(f"Total GPUs: {total_gpus_per_replica * dp_size}")
# output.append(f"Total nodes: {num_nodes_per_replica * dp_size}")
# output.append("")
#
# for dp in range(dp_size):
# replica_output = []
# for node in range(num_nodes_per_replica):
# start_gpu = (dp * total_gpus_per_replica) + (node * 8)
# node_output = create_node(start_gpu, tp_size, cp_size, pp_size, node)
# replica_output.append(f"Node {dp * num_nodes_per_replica + node}:")
# replica_output.extend(node_output)
# replica_output.append("")
#
# dp_box = create_dp_box(replica_output)
# output.append(f"Data Parallel Group {dp}:")
# output.extend(dp_box)
# output.append("")
#
# print("\n".join(output))
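
The commented-out display_4D_parallelism_grid helper maps a global GPU index to (TP, CP, PP) coordinates with nested modulo/division. The same decomposition written out on its own, with the DP coordinate added by the same pattern (the tp -> cp -> pp -> dp ordering is read from the code above; everything else is illustrative):

    def rank_to_coords(rank, tp_size, cp_size, pp_size):
        # Fastest-varying dimension first: tp, then cp, then pp, then dp.
        tp = rank % tp_size
        cp = (rank // tp_size) % cp_size
        pp = (rank // (tp_size * cp_size)) % pp_size
        dp = rank // (tp_size * cp_size * pp_size)
        return tp, cp, pp, dp

    # With tp=2, cp=2, pp=2: rank 0 -> (0, 0, 0, 0), rank 5 -> (1, 0, 1, 0), rank 11 -> (1, 1, 0, 1)
    for r in (0, 5, 11):
        print(r, rank_to_coords(r, 2, 2, 2))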