renaming

Summary of the renames in this commit:
- communicate -> pipeline_communicate
- bidirectional_communicate -> bidirectional_pipeline_communicate
- reduce_loss_across_dp_ranks -> all_reduce_loss_across_dp_ranks (moved into distributed/distributed_primtives.py)
- all_reduce_grads_across_dp_cp_ranks (train.py) -> all_reduce_gradients_across_dp_cp_ranks(model) in distributed/distributed_primtives.py
- display_parallelism_grid -> display_4D_parallelism_grid (the call in train.py and the definition in utils.py are now commented out)
commit 1e229cae88 (parent 3095ff4d4f)
.gitignore (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
+__pycache__
+*.pth
+.vscode/
distributed/distributed_primtives.py
@@ -5,7 +5,7 @@ import distributed.process_group_manager as pgm

 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

-def communicate(operation, device, dtype, tensor=None, shapes=None):
+def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
     global STEP
     global VERBOSE
     if operation == 'recv_forward':
@@ -31,7 +31,7 @@ def communicate(operation, device, dtype, tensor=None, shapes=None):
     if VERBOSE: STEP += 1
     return tensor if not is_send else None

-def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype):
+def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
     global STEP
     global VERBOSE
     is_fwd = (operation == 'send_fwd_recv_bwd')
@@ -43,4 +43,17 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype
     [req.wait() for req in reqs]
     torch.cuda.synchronize()
     if VERBOSE: STEP += 1
-    return recv_tensor
+    return recv_tensor
+
+def all_reduce_loss_across_dp_ranks(loss, device):
+    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
+    # Reduce the loss across all workers so that every rank has the updated loss value.
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
+    reduced_loss /= pgm.process_group_manager.dp_world_size
+    return reduced_loss.item()
+
+def all_reduce_gradients_across_dp_cp_ranks(model):
+    for param in model.parameters():
+        if param.grad is not None:
+            # Average the gradients across all DP & CP ranks
+            param.grad /= pgm.process_group_manager.cp_dp_world_size
+            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
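For orientation, here is a minimal single-process sketch (not part of this commit) of the sum-then-average pattern that the new all_reduce_loss_across_dp_ranks and all_reduce_gradients_across_dp_cp_ranks helpers implement. The gloo backend and the localhost rendezvous below are assumptions made only so the snippet runs standalone; the real helpers reduce over pgm.process_group_manager's world/dp and cp_dp groups instead.

    # Sum-all-reduce followed by a divide gives the mean across ranks.
    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")   # assumed rendezvous, standalone demo only
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    loss = 2.0
    reduced = torch.tensor([loss], dtype=torch.float32)
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)      # sum the per-rank values
    reduced /= dist.get_world_size()                    # ... then divide to get the average
    print(reduced.item())                               # 2.0 with a single rank

    dist.destroy_process_group()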
@@ -8,7 +8,7 @@ from utils import set_all_seed
 import distributed.process_group_manager as pgm
 from distributed.process_group_manager import setup_process_group_manager
 from parallel.pipeline_parallel import PipelineParallel
-from distributed.distributed_primtives import communicate
+from distributed.distributed_primtives import pipeline_communicate
 from model import Llama

 def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
@@ -24,14 +24,14 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
     if pgm.process_group_manager.pp_is_last_stage:
         logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device)

-    recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device)
+    recv_buffer = pipeline_communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device)

     batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer

     output_tensor = model.forward(batch, device)

     # Send output to the next stage.
-    communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)
+    pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)

     # Copy logits.
     if pgm.process_group_manager.pp_is_last_stage:
parallel/pipeline_parallel.py
@@ -1,15 +1,8 @@
 import distributed.process_group_manager as pgm
-from distributed.distributed_primtives import communicate, bidirectional_communicate
+from distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate, all_reduce_loss_across_dp_ranks
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist

-def reduce_loss_across_dp_ranks(loss, device):
-    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    # Reduce the loss across all workers so that every rank has the updated loss value.
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
-    reduced_loss /= pgm.process_group_manager.dp_world_size
-    return reduced_loss.item()
-
 class PipelineParallel(nn.Module):
     def __init__(self, model, config):
         super().__init__()
@@ -45,11 +38,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
     input_tensors, output_tensors = [], []

     for _ in range(data_loader.num_local_micro_batches): # All forward passes
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         batch = next(iter(data_loader))
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
-        communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)

         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
         if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
@@ -60,12 +53,12 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         output_tensors.append(output_tensor)

     for _ in range(data_loader.num_local_micro_batches): # All backward passes
-        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

-    logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
+    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss

 def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
@@ -85,33 +78,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         return output_tensor

     for _ in range(num_warmup_microbatches): # Warmup forward passes
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         output_tensor = _forward_step(input_tensor)
-        communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)

     if num_microbatches_remaining > 0:
-        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)

     for i in range(num_microbatches_remaining): # 1F1B steady state
         output_tensor = _forward_step(input_tensor)
-        output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
+        output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         if i == num_microbatches_remaining - 1: # last iteration
             input_tensor = None
-            communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
+            pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
         else:
-            input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
+            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)

     for _ in range(num_warmup_microbatches): # Cooldown backward passes
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
-        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

-    logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
+    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
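As a side note (this code is not from the repo), the two train-step variants above differ only in scheduling: AFAB runs all forwards and then all backwards, while 1F1B interleaves them after a warmup phase. A small illustrative sketch, assuming the usual warmup count of pp_size - pp_rank - 1; the actual value used by the trainer may be computed differently:

    # Print the per-stage order of microbatch forwards (F) and backwards (B).
    def afab_order(num_microbatches):
        # all-forward-all-backward: every forward first, then every backward
        return [f"F{i}" for i in range(num_microbatches)] + [f"B{i}" for i in range(num_microbatches)]

    def one_f_one_b_order(num_microbatches, pp_size, pp_rank):
        warmup = min(pp_size - pp_rank - 1, num_microbatches)   # assumed warmup formula
        order = [f"F{i}" for i in range(warmup)]
        f, b = warmup, 0
        for _ in range(num_microbatches - warmup):               # steady state: one forward, one backward
            order += [f"F{f}", f"B{b}"]
            f, b = f + 1, b + 1
        order += [f"B{i}" for i in range(b, num_microbatches)]   # cooldown backwards
        return order

    print(afab_order(4))                                # ['F0', 'F1', 'F2', 'F3', 'B0', 'B1', 'B2', 'B3']
    print(one_f_one_b_order(4, pp_size=4, pp_rank=1))   # ['F0', 'F1', 'F2', 'B0', 'F3', 'B1', 'B2', 'B3']

The interleaving is why 1F1B bounds the number of activations a stage keeps alive, while AFAB holds activations for every microbatch until the backward phase starts.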
train.py (25 changed lines)
@@ -10,13 +10,13 @@ from datasets import load_dataset
 import argparse

 import distributed.process_group_manager as pgm
-from utils import set_all_seed, display_parallelism_grid, print
+from distributed.distributed_primtives import all_reduce_gradients_across_dp_cp_ranks
+from utils import set_all_seed, print, display_4D_parallelism_grid
 from distributed.process_group_manager import setup_process_group_manager
 from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from parallel.data_parallel import DataParallel
 from parallel.context_parallel import ContextParallel
 from model import Llama
 from dataset import MicroBatchDataLoader
 import wandb

 class MicroBatchDataLoader(DataLoader):
@@ -68,13 +68,6 @@ def train_step(model, data_loader, device):
     avg_loss = total_loss / data_loader.num_local_micro_batches
     return avg_loss

-def all_reduce_grads_across_dp_cp_ranks():
-    for param in model.parameters():
-        if param.grad is not None:
-            # Average the gradients across all DP & CP ranks
-            param.grad /= pgm.process_group_manager.cp_dp_world_size
-            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--tp_size", type=int, default=1)
@@ -106,14 +99,13 @@ if __name__ == "__main__":
     else:
         device = torch.device("cpu")

-
     dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")

     setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size)

-    if pgm.process_group_manager.global_rank == 0:
-        display_parallelism_grid()
+    # if pgm.process_group_manager.global_rank == 0:
+    #     display_4D_parallelism_grid()

     set_all_seed(SEED)
     model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
     dataset_name = "roneneldan/TinyStories"
@@ -145,6 +137,7 @@
         config=config,
         device=device,
     ).to(device)

+    model.load_state_dict(torch.load("smollm.pth"))

     if pgm.process_group_manager.cp_size > 1:
@@ -156,6 +149,9 @@
     if pgm.process_group_manager.dp_world_size > 1:
         model = DataParallel(model, config).to(device)

+    # if pgm.process_group_manager.tp_world_size > 1:
+    #     model = TensorParallel(model, config).to(device)
+
     model.train()

     data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
@@ -181,11 +177,12 @@

         if pgm.process_group_manager.pp_world_size > 1:
             loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
+            # loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device)
         else:
             loss = train_step(model, data_loader, device)

         if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
-            all_reduce_grads_across_dp_cp_ranks()
+            all_reduce_gradients_across_dp_cp_ranks(model)

         optimizer.step()
         trained_tokens += tokens_per_step
utils.py (159 changed lines)
@@ -19,103 +19,62 @@ def set_all_seed(seed):
     torch.manual_seed(seed)
     if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

-def display_parallelism_grid():
-    def _create_gpu_box(gpu_num, tp, cp, pp, dp):
-        return [
-            f" GPU {gpu_num:<2} ",
-            f" +----------+",
-            f" | tp{tp} cp{cp} |",
-            f" | pp{pp} dp{dp} |",
-            f" +----------+"
-        ]
-
-    def _create_row(start_gpu, tp_size, cp, pp, dp):
-        boxes = [_create_gpu_box(start_gpu + i, i, cp, pp, dp) for i in range(tp_size)]
-        return [" ".join(row) for row in zip(*boxes)]
-
-    def _add_pp_label(output):
-        output.append(" | ")
-        output.append(" PP| ")
-        output.append(" | ")
-
-    def _add_cp_label(output):
-        output.append(" | CP")
-
-    def _add_vertical_separator(output):
-        output.append(" | ")
-        output.append(" | |")
-
-    def _add_vertical_arrow(output):
-        output.append(" | v")
-
-    def _add_horizontal_separator(output):
-        output.append("-" * 86)
-
-    def _create_tp_arrows_and_labels(tp_group_width):
-        tp_arrow = "-" * (tp_group_width - 4) + ">"
-        tp_label = f"{'TP':^{tp_group_width}}"
-        tp_arrows = f" {tp_arrow:<{tp_group_width}} {tp_arrow}"
-        tp_labels = f" {tp_label:<{tp_group_width}} {tp_label}"
-        return tp_arrows, tp_labels
-
-    def _create_dp_arrow_and_label(total_tp_width):
-        dp_arrow = "-" * (total_tp_width - 6) + ">"
-        dp_label = f"{'DP':^{total_tp_width}}"
-        return f" {dp_arrow}", f" {dp_label}"
-
-    output = []
-    tp_size = pgm.process_group_manager.tp_size
-    cp_size = pgm.process_group_manager.cp_size
-    pp_size = pgm.process_group_manager.pp_size
-    dp_size = pgm.process_group_manager.dp_size
-
-    output.append("=== Global Parallelism Configuration ===")
-    output.append(f"TP Size: {tp_size}, CP_size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}")
-    output.append("")
-
-    for dp in range(0, dp_size, 2):
-        output.append(" | ")
-
-        for pp in range(pp_size):
-            if pp == pp_size // 2:
-                _add_pp_label(output)
-
-            _add_vertical_separator(output)
-
-            for cp in range(cp_size):
-                left_start_gpu = dp * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size
-                left_row = _create_row(left_start_gpu, tp_size, cp, pp, dp)
-
-                if dp + 1 < dp_size:
-                    right_start_gpu = (dp+1) * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size
-                    right_row = _create_row(right_start_gpu, tp_size, cp, pp, dp+1)
-                    for l, r in zip(left_row, right_row):
-                        output.append(f" | | {l:<33} {r}")
-                else:
-                    for l in left_row:
-                        output.append(f" | | {l}")
-
-                if cp < cp_size - 1:
-                    _add_cp_label(output)
-                    output.append(" | |")
-
-                _add_vertical_arrow(output)
-
-            if pp < pp_size - 1:
-                output.append(" | ")
-
-        output.append(" | ")
-        output.append(" v ")
-
-        if dp + 2 < dp_size:
-            _add_horizontal_separator(output)
-
-    tp_group_width = tp_size * 13 - 1
-    total_tp_width = tp_group_width * 2 + 18
-
-    tp_arrows, tp_labels = _create_tp_arrows_and_labels(tp_group_width)
-    dp_arrow, dp_label = _create_dp_arrow_and_label(total_tp_width)
-
-    output.extend(["", tp_arrows, tp_labels, "", dp_arrow, dp_label])
-
-    print("\n".join(output))
+## def display_4D_parallelism_grid():
+# #TODO(fmom): fix me
+# #TODO(fmom): add color to distinguish between different parallelism groups
+# def create_gpu_box(gpu_num, tp, cp, pp):
+# return [
+# f"+------+",
+# f"|GPU:{gpu_num:<2d}|",
+# f"| TP:{tp:d} |",
+# f"| CP:{cp:d} |",
+# f"| PP:{pp:d} |",
+# f"+------+"
+# ]
+#
+# def create_node(start_gpu, tp_size, cp_size, pp_size, node_index):
+# boxes = []
+# for i in range(8): # 8 GPUs per node
+# gpu_num = start_gpu + i
+# tp = gpu_num % tp_size
+# cp = (gpu_num // tp_size) % cp_size
+# pp = (gpu_num // (tp_size * cp_size)) % pp_size
+# boxes.append(create_gpu_box(gpu_num, tp, cp, pp))
+# return [' '.join(row) for row in zip(*boxes)]
+#
+# def create_dp_box(replica_output):
+# width = len(replica_output[0]) + 4
+# top_bottom = f"+{'-' * (width - 2)}+"
+# return [top_bottom] + [f"| {line} |" for line in replica_output] + [top_bottom]
+#
+# tp_size = pgm.process_group_manager.tp_size
+# cp_size = pgm.process_group_manager.cp_size
+# pp_size = pgm.process_group_manager.pp_size
+# dp_size = pgm.process_group_manager.dp_size
+# total_gpus_per_replica = tp_size * cp_size * pp_size
+# num_nodes_per_replica = (total_gpus_per_replica + 7) // 8 # Round up to nearest whole node
+#
+# output = []
+# output.append("=== Simplified Parallelism Configuration ===")
+# output.append(f"TP Size: {tp_size}, CP Size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}")
+# output.append(f"Total GPUs for one replica: {total_gpus_per_replica}")
+# output.append(f"Number of nodes per replica: {num_nodes_per_replica} (8 GPUs per node)")
+# output.append(f"Total GPUs: {total_gpus_per_replica * dp_size}")
+# output.append(f"Total nodes: {num_nodes_per_replica * dp_size}")
+# output.append("")
+#
+# for dp in range(dp_size):
+# replica_output = []
+# for node in range(num_nodes_per_replica):
+# start_gpu = (dp * total_gpus_per_replica) + (node * 8)
+# node_output = create_node(start_gpu, tp_size, cp_size, pp_size, node)
+# replica_output.append(f"Node {dp * num_nodes_per_replica + node}:")
+# replica_output.extend(node_output)
+# replica_output.append("")
+#
+# dp_box = create_dp_box(replica_output)
+# output.append(f"Data Parallel Group {dp}:")
+# output.extend(dp_box)
+# output.append("")
+#
+# print("\n".join(output))