fixing socket bug by using dist.new_subgroups_by_enumeration instead of dist.new_group
parent 7a57407c54
commit 7ba1383ebb
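The likely cause of the socket errors: dist.new_group is a collective over the default group, so every rank must create every subgroup with identical rank lists in the same order. The removed code had each rank create only its own groups, which leaves ranks blocking on mismatched store/socket handshakes. dist.new_subgroups_by_enumeration instead takes the full enumeration of subgroups on every rank and returns the group containing the caller. A minimal standalone sketch of the difference (not part of the commit), assuming a hypothetical 4-rank job split into two data-parallel groups:

import torch
import torch.distributed as dist

# Sketch only. Assumes 4 ranks launched with `torchrun --nproc_per_node=4 sketch.py`
# and two DP groups: [0, 1] and [2, 3].
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Hang-prone pattern (what the removed code did): each rank passes only the
# ranks of *its own* group. new_group is collective over the default group,
# so all ranks must create every subgroup, with identical rank lists.
# my_dp_ranks = [0, 1] if rank < 2 else [2, 3]
# dp_group = dist.new_group(my_dp_ranks)

# Pattern used by this commit: enumerate every subgroup once, identically on
# all ranks; the call returns (subgroup containing this rank, all subgroups).
dp_group, _all_dp_groups = dist.new_subgroups_by_enumeration([[0, 1], [2, 3]])

t = torch.tensor([float(rank)])
dist.all_reduce(t, group=dp_group)  # reduces only within this rank's DP group
print(f"rank {rank}: {t.item()}")   # 1.0 on ranks 0/1, 5.0 on ranks 2/3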
@@ -13,18 +13,19 @@ class ParallelContext:
        self.dp_size = dp_size
        assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"

-        self.grid = torch.arange(self.world_size).view(self.pp_size, self.dp_size, self.tp_size).permute(2, 0, 1)
+        self.grid = torch.arange(self.world_size).view(self.tp_size, self.pp_size, self.dp_size,) # TP * PP * DP grid
        # Find the position of the current process in the grid
        self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()

        # Process group creation
+        self.dp_group = dist.new_subgroups_by_enumeration([self.grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)])[0]
+        self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)])[0]
+        self.pp_group = dist.new_subgroups_by_enumeration([self.grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)])[0]

        self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
        self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
        self.dp_group_ids = self.grid[self.tp_rank, self.pp_rank, :].tolist()
-        self.tp_group = dist.new_group(self.tp_group_ids)
-        self.pp_group = dist.new_group(self.pp_group_ids)
-        self.dp_group = dist.new_group(self.dp_group_ids)


        # Tensor parallelism
        self.tp_first_rank = self.tp_group_ids[0]
        self.tp_last_rank = self.tp_group_ids[-1]
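For reference, a small standalone sketch (assumed sizes TP=PP=DP=2, not taken from the commit) of what the new (TP, PP, DP) grid looks like: the nonzero() lookup maps a flat rank to its coordinates, and the three comprehensions above enumerate one rank list per pair of fixed axes.

import torch

# Sketch only, with assumed sizes.
tp_size, pp_size, dp_size = 2, 2, 2
world_size = tp_size * pp_size * dp_size
grid = torch.arange(world_size).view(tp_size, pp_size, dp_size)

# Coordinates of, say, global rank 5: the equality mask has a single nonzero
# entry, whose index is the (tp, pp, dp) triple.
tp_rank, pp_rank, dp_rank = (grid == 5).nonzero().flatten().tolist()  # -> (1, 0, 1)

# Rank lists handed to new_subgroups_by_enumeration: fix two axes, vary the third.
dp_groups = [grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)]
tp_groups = [grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)]
pp_groups = [grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)]
print(dp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(tp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(pp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]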
@@ -47,56 +48,6 @@ class ParallelContext:
    def __str__(self):
        return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})"

-    def display_parallelism_grid(self):
-        def _create_box(content):
-            return f" {content:^3} "
-
-        def _create_row(row):
-            return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
-
-        def _create_border(width):
-            return "+" + "-" * (width - 2) + "+"
-
-        def _create_pp_line(width, pp_size):
-            box_width = (width - pp_size + 1) // pp_size
-            return " ".join("PP".center(box_width) for _ in range(pp_size))
-
-        output = []
-        sample_row = _create_row(self.grid[0, :, 0])
-        row_width = len(sample_row)
-        border = _create_border(row_width)
-
-        output.append(f"=== Global Parallelism Configuration ===")
-        output.append(f"DP Size: {self.dp_size}, PP Size: {self.pp_size}, TP Size: {self.grid.shape[0]}")
-        output.append("") # Top spacing
-
-        for dp in range(self.dp_size):
-            output.append(f"DP {dp}:")
-            output.append(f"{'':>8}{border}")
-
-            for tp in range(self.grid.shape[0]):
-                if tp == 0:
-                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
-                else:
-                    output.append(f"{'':8}{border}")
-                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
-
-            output.append(f"{'':8}{border}")
-            if self.pp_size > 1:
-                output.append(f"{'':>7}{_create_pp_line(row_width, self.pp_size)}")
-
-            output.append("") # Spacing between DP blocks
-
-        output.append("") # Bottom spacing
-
-        output.append(f"=== Local Parallelism Configuration ===")
-        output.append(self.__str__())
-        output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in self.dp_group_ids]}")
-        output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in self.pp_group_ids]}")
-        output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in self.tp_group_ids]}")
-
-        print("\n".join(output))

def setup_parallel_context(tp_size, pp_size, dp_size):
    global parallel_context
    parallel_context = ParallelContext(tp_size, pp_size, dp_size)
@@ -7,9 +7,7 @@ import torch.distributed as dist
def reduce_loss_across_dp_ranks(loss, device):
    # Reduce the loss across DP workers.
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pc.parallel_context.dp_group)
-    # Average the loss across DP workers.
-    reduced_loss /= pc.parallel_context.world_size
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group)
    return reduced_loss.item()

class PipelineParallel(nn.Module):
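The reduction change does two things: it averages in a single collective, and the divisor becomes the size of the DP group (ReduceOp.AVG divides by the number of ranks in the group it runs over) rather than the world size used by the removed line. A rough sketch of the equivalence, with dp_group and dp_size assumed to come from the parallel context; note that ReduceOp.AVG is only supported on some backends (e.g. NCCL):

import torch
import torch.distributed as dist

def average_over_dp(loss, dp_group, dp_size, device):
    # Sketch only: average a scalar loss across the data-parallel group.
    t = torch.tensor([loss], dtype=torch.float32, device=device)
    if dist.get_backend(dp_group) == "nccl":
        # ReduceOp.AVG divides by the number of ranks in dp_group.
        dist.all_reduce(t, op=dist.ReduceOp.AVG, group=dp_group)
    else:
        # Equivalent fallback for backends without AVG: sum, then divide
        # by the DP group size (not the world size).
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=dp_group)
        t /= dp_size
    return t.item()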
@@ -116,7 +114,6 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(operation='send_backward', tensor=input_tensor_grad)

    logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
    return logging_loss
train.py
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
import argparse

import parallel_context as pc
-from utils import set_all_seed
+from utils import set_all_seed, display_parallelism_grid
from parallel_context import setup_parallel_context
from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel

@@ -51,7 +51,7 @@ if __name__ == "__main__":
    setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)

    if pc.parallel_context.global_rank == local_rank:
-        pc.parallel_context.display_parallelism_grid()
+        display_parallelism_grid()

    set_all_seed(seed=42)
    model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
@@ -61,14 +61,15 @@ if __name__ == "__main__":
    trained_tokens, step = 0, 0
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN

-    #TODO: Profile memory
+    #TODO: hanging

    while trained_tokens < MAX_TOKENS:
        optimizer.zero_grad()
-        loss = pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device)
+        loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device)
        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1
-        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank:
+        #NOTE(fmom): change later to log on rank 0 (g00) everytime ?
+        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank:
            print(f"[rank {pc.parallel_context.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")

    dist.destroy_process_group()
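The new logging guard fires on exactly one rank per run: the rank whose coordinates are first TP rank, last PP stage, and first DP rank. Assuming pp_is_last_stage and dp_first_rank are defined analogously to tp_first_rank in the earlier hunk, a quick check with the same assumed 2x2x2 grid:

import torch

# Sketch only, with assumed sizes (TP=PP=DP=2): find the single rank that is
# simultaneously first in its TP group, on the last PP stage, and first in its DP group.
tp_size, pp_size, dp_size = 2, 2, 2
grid = torch.arange(tp_size * pp_size * dp_size).view(tp_size, pp_size, dp_size)
logging_rank = grid[0, pp_size - 1, 0].item()
print(logging_rank)  # 2 for this grid: tp_rank=0, pp_rank=1 (last stage), dp_rank=0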
utils.py
@@ -1,6 +1,57 @@
import torch, random, numpy as np
+import parallel_context as pc

def set_all_seed(seed):
    for module in [random, np.random]: module.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+
+def display_parallelism_grid():
+    def _create_box(content):
+        return f" {content:^3} "
+
+    def _create_row(row):
+        return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
+
+    def _create_border(width):
+        return "+" + "-" * (width - 2) + "+"
+
+    def _create_pp_line(width, pp_size):
+        box_width = (width - pp_size + 1) // pp_size
+        return " ".join("PP".center(box_width) for _ in range(pp_size))
+
+    output = []
+    sample_row = _create_row(pc.parallel_context.grid[0, :, 0])
+    row_width = len(sample_row)
+    border = _create_border(row_width)
+
+    output.append(f"=== Global Parallelism Configuration ===")
+    output.append(f"DP Size: {pc.parallel_context.dp_size}, PP Size: {pc.parallel_context.pp_size}, TP Size: {pc.parallel_context.grid.shape[0]}")
+    output.append("") # Top spacing
+
+    for dp in range(pc.parallel_context.dp_size):
+        output.append(f"DP {dp}:")
+        output.append(f"{'':>8}{border}")
+
+        for tp in range(pc.parallel_context.grid.shape[0]):
+            if tp == 0:
+                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
+            else:
+                output.append(f"{'':8}{border}")
+                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
+
+        output.append(f"{'':8}{border}")
+        if pc.parallel_context.pp_size > 1:
+            output.append(f"{'':>7}{_create_pp_line(row_width, pc.parallel_context.pp_size)}")
+
+        output.append("") # Spacing between DP blocks
+
+    output.append("") # Bottom spacing
+
+    output.append(f"=== Local Parallelism Configuration ===")
+    output.append(pc.parallel_context.__str__())
+    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.dp_group_ids]}")
+    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.pp_group_ids]}")
+    output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.tp_group_ids]}")
+
+    print("\n".join(output))