fixing socket bug by using dist.new_subgroups_by_enumeration instead

ferdinand.mom 2024-09-24 13:43:22 +00:00
parent 7a57407c54
commit 7ba1383ebb
4 changed files with 66 additions and 66 deletions

View File

@@ -13,18 +13,19 @@ class ParallelContext:
        self.dp_size = dp_size
        assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"
        self.grid = torch.arange(self.world_size).view(self.pp_size, self.dp_size, self.tp_size).permute(2, 0, 1)
        self.grid = torch.arange(self.world_size).view(self.tp_size, self.pp_size, self.dp_size,) # TP * PP * DP grid
        # Find the position of the current process in the grid
        self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()
        # Process group creation
        self.dp_group = dist.new_subgroups_by_enumeration([self.grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)])[0]
        self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)])[0]
        self.pp_group = dist.new_subgroups_by_enumeration([self.grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)])[0]
        self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
        self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
        self.dp_group_ids = self.grid[self.tp_rank, self.pp_rank, :].tolist()
        self.tp_group = dist.new_group(self.tp_group_ids)
        self.pp_group = dist.new_group(self.pp_group_ids)
        self.dp_group = dist.new_group(self.dp_group_ids)
        # Tensor parallelism
        self.tp_first_rank = self.tp_group_ids[0]
        self.tp_last_rank = self.tp_group_ids[-1]
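
As a side note on the rank lookup above: for a hypothetical tp_size = pp_size = dp_size = 2 layout (illustration only, not the repository's code), the nonzero() lookup recovers the three grid coordinates of the current global rank like this:

import torch

grid = torch.arange(8).view(2, 2, 2)  # hypothetical [tp, pp, dp] grid, ranks g00..g07
global_rank = 6
# A single True entry marks this rank's position; its indices are the coordinates.
tp_rank, pp_rank, dp_rank = (grid == global_rank).nonzero().flatten().tolist()
print(tp_rank, pp_rank, dp_rank)  # 1 1 0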
@@ -47,56 +48,6 @@ class ParallelContext:
    def __str__(self):
        return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})"
    def display_parallelism_grid(self):
        def _create_box(content):
            return f" {content:^3} "
        def _create_row(row):
            return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
        def _create_border(width):
            return "+" + "-" * (width - 2) + "+"
        def _create_pp_line(width, pp_size):
            box_width = (width - pp_size + 1) // pp_size
            return " ".join("PP".center(box_width) for _ in range(pp_size))
        output = []
        sample_row = _create_row(self.grid[0, :, 0])
        row_width = len(sample_row)
        border = _create_border(row_width)
        output.append(f"=== Global Parallelism Configuration ===")
        output.append(f"DP Size: {self.dp_size}, PP Size: {self.pp_size}, TP Size: {self.grid.shape[0]}")
        output.append("") # Top spacing
        for dp in range(self.dp_size):
            output.append(f"DP {dp}:")
            output.append(f"{'':>8}{border}")
            for tp in range(self.grid.shape[0]):
                if tp == 0:
                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
                else:
                    output.append(f"{'':8}{border}")
                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
            output.append(f"{'':8}{border}")
            if self.pp_size > 1:
                output.append(f"{'':>7}{_create_pp_line(row_width, self.pp_size)}")
            output.append("") # Spacing between DP blocks
        output.append("") # Bottom spacing
        output.append(f"=== Local Parallelism Configuration ===")
        output.append(self.__str__())
        output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in self.dp_group_ids]}")
        output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in self.pp_group_ids]}")
        output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in self.tp_group_ids]}")
        print("\n".join(output))
def setup_parallel_context(tp_size, pp_size, dp_size):
    global parallel_context
    parallel_context = ParallelContext(tp_size, pp_size, dp_size)
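
For reference, a minimal standalone sketch (not part of this commit) of how the TP * PP * DP grid enumerates the rank lists that dist.new_subgroups_by_enumeration consumes, using the same hypothetical 2x2x2 sizes as above. Unlike dist.new_group called with a per-rank group_ids list, the enumeration builds every subgroup in the same order on every rank, which keeps the collective group-creation calls in sync across processes and is presumably what avoids the socket error mentioned in the commit message.

import torch
import torch.distributed as dist

# Hypothetical sizes for illustration only.
tp_size, pp_size, dp_size = 2, 2, 2
world_size = tp_size * pp_size * dp_size

# Grid indexed as [tp, pp, dp]; global rank g sits at grid[tp_rank, pp_rank, dp_rank].
grid = torch.arange(world_size).view(tp_size, pp_size, dp_size)

# Every rank enumerates *all* subgroups, in the same order.
dp_ranks = [grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)]
tp_ranks = [grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)]
pp_ranks = [grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)]
print(dp_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(tp_ranks)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(pp_ranks)  # [[0, 2], [1, 3], [4, 6], [5, 7]]

# With dist.init_process_group() already called on every rank,
# new_subgroups_by_enumeration returns (subgroup containing this rank, all subgroups):
# dp_group, _ = dist.new_subgroups_by_enumeration(dp_ranks)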

View File

@@ -7,9 +7,7 @@ import torch.distributed as dist
def reduce_loss_across_dp_ranks(loss, device):
    # Reduce the loss across DP workers.
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pc.parallel_context.dp_group)
    # Average the loss across DP workers.
    reduced_loss /= pc.parallel_context.world_size
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group)
    return reduced_loss.item()
class PipelineParallel(nn.Module):
@@ -116,7 +114,6 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(operation='send_backward', tensor=input_tensor_grad)
    logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
    return logging_loss
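
A minimal sketch of the corrected reduction (the helper name here is made up, not the repository's): ReduceOp.AVG divides by the size of the group it runs over, i.e. dp_size, whereas the removed SUM followed by "/= world_size" only matches the true data-parallel mean when dp_size == world_size, that is when tp_size and pp_size are both 1.

from typing import Optional

import torch
import torch.distributed as dist

def average_loss_over_dp(loss: Optional[float], dp_group, device) -> float:
    # None (no loss computed on this rank) is treated as 0.0, matching the helper above.
    t = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # Averages over dp_group only; note that ReduceOp.AVG requires the NCCL backend.
    dist.all_reduce(t, op=dist.ReduceOp.AVG, group=dp_group)
    return t.item()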

View File

@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
import argparse
import parallel_context as pc
from utils import set_all_seed
from utils import set_all_seed, display_parallelism_grid
from parallel_context import setup_parallel_context
from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel
@@ -51,7 +51,7 @@ if __name__ == "__main__":
    setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
    if pc.parallel_context.global_rank == local_rank:
        pc.parallel_context.display_parallelism_grid()
        display_parallelism_grid()
    set_all_seed(seed=42)
    model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
@@ -61,14 +61,15 @@ if __name__ == "__main__":
    trained_tokens, step = 0, 0
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
    #TODO: Profile memory
    #TODO: hanging
    while trained_tokens < MAX_TOKENS:
        optimizer.zero_grad()
        loss = pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device)
        loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device)
        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1
        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank:
        #NOTE(fmom): change later to log on rank 0 (g00) everytime ?
        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank:
            print(f"[rank {pc.parallel_context.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
    dist.destroy_process_group()
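
The tightened logging condition selects the single rank that is simultaneously on the last PP stage, first in its TP group, and first in its DP group. A small sketch of which global rank that is, reusing the hypothetical 2x2x2 grid from above (not the repository's code):

import torch

tp_size, pp_size, dp_size = 2, 2, 2
grid = torch.arange(tp_size * pp_size * dp_size).view(tp_size, pp_size, dp_size)  # [tp, pp, dp]

# tp_rank == 0 (first in TP group), pp_rank == pp_size - 1 (last stage), dp_rank == 0 (first in DP group)
logging_rank = grid[0, pp_size - 1, 0].item()
print(logging_rank)  # 2, i.e. "g02" in display_parallelism_grid's notation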

View File

@@ -1,6 +1,57 @@
import torch, random, numpy as np
import parallel_context as pc
def set_all_seed(seed):
    for module in [random, np.random]: module.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
def display_parallelism_grid():
    def _create_box(content):
        return f" {content:^3} "
    def _create_row(row):
        return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
    def _create_border(width):
        return "+" + "-" * (width - 2) + "+"
    def _create_pp_line(width, pp_size):
        box_width = (width - pp_size + 1) // pp_size
        return " ".join("PP".center(box_width) for _ in range(pp_size))
    output = []
    sample_row = _create_row(pc.parallel_context.grid[0, :, 0])
    row_width = len(sample_row)
    border = _create_border(row_width)
    output.append(f"=== Global Parallelism Configuration ===")
    output.append(f"DP Size: {pc.parallel_context.dp_size}, PP Size: {pc.parallel_context.pp_size}, TP Size: {pc.parallel_context.grid.shape[0]}")
    output.append("") # Top spacing
    for dp in range(pc.parallel_context.dp_size):
        output.append(f"DP {dp}:")
        output.append(f"{'':>8}{border}")
        for tp in range(pc.parallel_context.grid.shape[0]):
            if tp == 0:
                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
            else:
                output.append(f"{'':8}{border}")
                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
        output.append(f"{'':8}{border}")
        if pc.parallel_context.pp_size > 1:
            output.append(f"{'':>7}{_create_pp_line(row_width, pc.parallel_context.pp_size)}")
        output.append("") # Spacing between DP blocks
    output.append("") # Bottom spacing
    output.append(f"=== Local Parallelism Configuration ===")
    output.append(pc.parallel_context.__str__())
    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.dp_group_ids]}")
    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.pp_group_ids]}")
    output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.tp_group_ids]}")
    print("\n".join(output))