rename parallel_context to process_group_manager
commit b2e276d3b8 (parent 9e9ef8236e)
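Every call site in this commit follows the same pattern: the module-level singleton moves from pc.parallel_context to pgm.process_group_manager, and setup_parallel_context(...) becomes setup_process_group_manager(...). A minimal sketch of the new access pattern, assuming a torchrun-style launch (the LOCAL_RANK handling below is an assumption; the setup call and attribute names are taken from this diff):

# Hedged sketch: initialize the renamed singleton, then read it through the pgm alias.
import os
import torch
import torch.distributed as dist

import process_group_manager as pgm
from process_group_manager import setup_process_group_manager

local_rank = int(os.environ["LOCAL_RANK"])   # assumption: set by torchrun
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)

# Populate the module-level singleton, then read it anywhere via the pgm alias.
setup_process_group_manager(tp_size=1, pp_size=2, dp_size=1)
print(pgm.process_group_manager.pp_rank, pgm.process_group_manager.pp_is_last_stage)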
data_parallel.py

@@ -1,6 +1,6 @@
 import torch.distributed as dist
 import torch.nn as nn
-import parallel_context as pc
+import process_group_manager as pgm

 class DataParallel(nn.Module):
     def __init__(self, model, config):
@@ -8,8 +8,8 @@ class DataParallel(nn.Module):
         #TODO: Interleave all_reduce
         super().__init__()
         self.model = model
-        self.dp_world_size = pc.parallel_context.dp_world_size
-        self.dp_rank = pc.parallel_context.dp_rank
+        self.dp_world_size = pgm.process_group_manager.dp_world_size
+        self.dp_rank = pgm.process_group_manager.dp_rank

     def forward(self, *args, **kwargs):
         return self.model(*args, **kwargs)
@@ -20,5 +20,5 @@ class DataParallel(nn.Module):
     def all_reduce_gradients(self):
         for param in self.model.parameters():
             if param.grad is not None:
-                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group)
+                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
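For context, the training loop later in this diff drives this wrapper explicitly: gradients are averaged across the dp_group only after the local backward pass. A hedged, toy-sized sketch of that flow (the nn.Linear stand-in model, the squared-output loss, and config=None are assumptions; the wrapper usage and the all_reduce_gradients() call mirror train.py further down):

# Sketch: one training step with manual DP gradient averaging.
# Assumes dist.init_process_group(...) and setup_process_group_manager(...) have already run
# (see the sketch near the top of this page).
import torch, torch.nn as nn
import process_group_manager as pgm
from data_parallel import DataParallel

device = torch.device("cuda", torch.cuda.current_device())
model = nn.Linear(16, 16).to(device)        # toy stand-in for the real model
model = DataParallel(model, config=None)    # assumption: config is not needed for this sketch
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(4, 16, device=device)       # each DP rank sees its own shard of data
optimizer.zero_grad()
loss = model(x).pow(2).mean()               # any scalar loss works for the sketch
loss.backward()                             # gradients are still local to this rank
if pgm.process_group_manager.dp_world_size > 1:
    model.all_reduce_gradients()            # dist.all_reduce(..., op=AVG) over the dp_group
optimizer.step()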
dataset.py

@@ -4,21 +4,22 @@ from transformers import AutoTokenizer
 from torch.utils.data import DataLoader, DistributedSampler
 from datasets import load_dataset

-import parallel_context as pc
+import process_group_manager as pgm

 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
         self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
-        self.local_batch_size = self.global_batch_size // pc.parallel_context.dp_world_size
+        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
         self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
         self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size

         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         self.dataset = load_dataset(dataset_name, split=split)
         if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
         dist.barrier()
         self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])

-        self.sampler = DistributedSampler(self.dataset, num_replicas=pc.parallel_context.dp_world_size, rank=pc.parallel_context.dp_rank, shuffle=False)
+        self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)

         super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
distributed_primtives.py

@@ -1,7 +1,7 @@
 import os
-import parallel_context as pc
+import process_group_manager as pgm
 import torch, torch.distributed as dist
-import parallel_context as pc
+import process_group_manager as pgm

 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

@@ -9,23 +9,23 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
     global STEP
     global VERBOSE
     if operation == 'recv_forward':
-        if pc.parallel_context.pp_is_first_stage: return None
+        if pgm.process_group_manager.pp_is_first_stage: return None
         tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
-        src = pc.parallel_context.pp_prev_rank
+        src = pgm.process_group_manager.pp_prev_rank
     elif operation == 'send_forward':
-        if pc.parallel_context.pp_is_last_stage: return
-        dest = pc.parallel_context.pp_next_rank
+        if pgm.process_group_manager.pp_is_last_stage: return
+        dest = pgm.process_group_manager.pp_next_rank
     elif operation == 'recv_backward':
-        if pc.parallel_context.pp_is_last_stage: return None
+        if pgm.process_group_manager.pp_is_last_stage: return None
         tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
-        src = pc.parallel_context.pp_next_rank
+        src = pgm.process_group_manager.pp_next_rank
     elif operation == 'send_backward':
-        if pc.parallel_context.pp_is_first_stage: return
-        dest = pc.parallel_context.pp_prev_rank
+        if pgm.process_group_manager.pp_is_first_stage: return
+        dest = pgm.process_group_manager.pp_prev_rank
     is_send = operation.startswith('send')
     peer_rank = dest if is_send else src
     op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
-    if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pc.parallel_context.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pc.parallel_context.pp_rank}", flush=True)
+    if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
     [req.wait() for req in dist.batch_isend_irecv([op])]
     torch.cuda.synchronize()
     if VERBOSE: STEP += 1
@@ -35,11 +35,11 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device
     global STEP
     global VERBOSE
     is_fwd = (operation == 'send_fwd_recv_bwd')
-    if (is_fwd and pc.parallel_context.pp_is_last_stage) or (not is_fwd and pc.parallel_context.pp_is_first_stage): return None
-    peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank
+    if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None
+    peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
     recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
     reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
-    if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pc.parallel_context.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pc.parallel_context.pp_rank} | "f"STEP {STEP=} | RANK:{pc.parallel_context.pp_rank}", flush=True)
+    if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
     [req.wait() for req in reqs]
     torch.cuda.synchronize()
     if VERBOSE: STEP += 1
generate.py (22 changed lines)

@@ -5,13 +5,13 @@ import torch, torch.distributed as dist
 from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM,AutoTokenizer

 from utils import set_all_seed
-import parallel_context as pc
-from parallel_context import setup_parallel_context
+import process_group_manager as pgm
+from process_group_manager import setup_process_group_manager
 from pipeline_parallel import PipelineParallel
 from distributed_primtives import communicate

 def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
-    if pc.parallel_context.pp_world_size == 1:
+    if pgm.process_group_manager.pp_world_size == 1:
         return model.forward(batch, device)

     batch_size = batch["input_ids"].shape[0]
@@ -20,12 +20,12 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor:

     # Preallocate memory for output logits.
     logits = None
-    if pc.parallel_context.pp_is_last_stage:
+    if pgm.process_group_manager.pp_is_last_stage:
         logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device)

     recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32)

-    batch["hidden_states"] = None if pc.parallel_context.pp_is_first_stage else recv_buffer
+    batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer

     output_tensor = model.forward(batch, device)

@@ -33,7 +33,7 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
     communicate(operation="send_forward", tensor=output_tensor)

     # Copy logits.
-    if pc.parallel_context.pp_is_last_stage:
+    if pgm.process_group_manager.pp_is_last_stage:
         logits = output_tensor

     dist.barrier()
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(local_rank)
     device = torch.device("cuda", local_rank)
-    setup_parallel_context(tp_size=1, pp_size=args.pp_size, dp_size=1)
+    setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1)
     set_all_seed(seed=42)
     model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
     config = AutoConfig.from_pretrained(model_name)
@@ -92,7 +92,7 @@ if __name__ == "__main__":
         logits = run_one_inference_step(model, batch_prompts, device, config)

         # Sample new token
-        if pc.parallel_context.pp_is_last_stage:
+        if pgm.process_group_manager.pp_is_last_stage:
             assert logits is not None
             next_token = torch.argmax(logits[:, -1], dim=-1)
             tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
@@ -101,11 +101,11 @@ if __name__ == "__main__":
             tokenized_prompts["input_ids"] = torch.zeros((tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, device=device)
             tokenized_prompts["attention_mask"] = torch.zeros((tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), dtype=torch.int64, device=device)

-        dist.broadcast(tokenized_prompts["input_ids"], src=pc.parallel_context.pp_last_rank)
-        dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank)
+        dist.broadcast(tokenized_prompts["input_ids"], src=pgm.process_group_manager.pp_last_rank)
+        dist.broadcast(tokenized_prompts["attention_mask"], src=pgm.process_group_manager.pp_last_rank)

     # Get only the new generated tokens
-    if pc.parallel_context.pp_is_last_stage:
+    if pgm.process_group_manager.pp_is_last_stage:
         for i, prompt in enumerate(prompts):
             tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
             outputs = tokenizer.decode(tokenized_outputs)
pipeline_parallel.py

@@ -1,4 +1,4 @@
-import parallel_context as pc
+import process_group_manager as pgm
 from distributed_primtives import communicate, bidirectional_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
@@ -6,22 +6,22 @@ import torch.distributed as dist
 def reduce_loss_across_dp_ranks(loss, device):
     # Reduce the loss across DP workers.
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group)
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
     return reduced_loss.item()

 class PipelineParallel(nn.Module):
     def __init__(self, model, config):
         super().__init__()
         layer_distribution = self.distribute_layers(config.num_hidden_layers)
-        self.embed_tokens = model.model.embed_tokens if pc.parallel_context.pp_is_first_stage else nn.Identity()
+        self.embed_tokens = model.model.embed_tokens if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
         self.decoder_layers = nn.ModuleDict({str(i): model.model.layers[i] for i in layer_distribution})
-        self.norm = model.model.norm if pc.parallel_context.pp_is_last_stage else nn.Identity()
-        self.lm_head = model.lm_head if pc.parallel_context.pp_is_last_stage else nn.Identity()
+        self.norm = model.model.norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
+        self.lm_head = model.lm_head if pgm.process_group_manager.pp_is_last_stage else nn.Identity()

     def distribute_layers(self, num_layers):
-        layers_per_gpu = [num_layers // pc.parallel_context.pp_world_size + (1 if i < num_layers % pc.parallel_context.pp_world_size else 0) for i in range(pc.parallel_context.pp_world_size)]
-        start_layer = sum(layers_per_gpu[:pc.parallel_context.pp_rank])
-        return list(range(start_layer, start_layer + layers_per_gpu[pc.parallel_context.pp_rank]))
+        layers_per_gpu = [num_layers // pgm.process_group_manager.pp_world_size + (1 if i < num_layers % pgm.process_group_manager.pp_world_size else 0) for i in range(pgm.process_group_manager.pp_world_size)]
+        start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank])
+        return list(range(start_layer, start_layer + layers_per_gpu[pgm.process_group_manager.pp_rank]))

     def forward(self, batch, device):
         x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device)
@@ -50,7 +50,7 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         communicate(operation='send_forward', tensor=output_tensor)

         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
-        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank:
+        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             logging_loss += output_tensor.item()

@@ -67,7 +67,7 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
     return logging_loss

 def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
-    num_warmup_microbatches = min(pc.parallel_context.pp_world_size - pc.parallel_context.pp_rank - 1, data_loader.num_local_micro_batches)
+    num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches)
     num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
     logging_loss, input_tensors, output_tensors = 0.0, [], []

@@ -76,7 +76,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
-        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank:
+        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             nonlocal logging_loss
             logging_loss += output_tensor.item()
process_group_manager.py (renamed from parallel_context.py)

@@ -2,7 +2,7 @@ import os
 import torch
 import torch.distributed as dist

-class ParallelContext:
+class ProcessGroupManager:
     def __init__(self, tp_size, pp_size, dp_size):
         self.global_rank = dist.get_rank()
         self.world_size = dist.get_world_size()
@@ -48,6 +48,6 @@ class ParallelContext:
     def __str__(self):
         return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})"

-def setup_parallel_context(tp_size, pp_size, dp_size):
-    global parallel_context
-    parallel_context = ParallelContext(tp_size, pp_size, dp_size)
+def setup_process_group_manager(tp_size, pp_size, dp_size):
+    global process_group_manager
+    process_group_manager = ProcessGroupManager(tp_size, pp_size, dp_size)
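Only the class name and the setup helper change in this file; the attribute surface the rest of the commit reads from the singleton is untouched and mostly not shown in the hunks above. For reference, a hypothetical stub listing the attributes that the call sites in this diff rely on, reconstructed purely from usage rather than from the actual class body:

# Hypothetical stub: attributes of process_group_manager used elsewhere in this commit.
class ProcessGroupManagerSurface:
    global_rank: int
    world_size: int
    dp_size: int; pp_size: int; tp_size: int              # grid dimensions
    dp_rank: int; dp_world_size: int                      # data-parallel coordinates
    dp_group: object; dp_first_rank: int; dp_group_ids: list
    pp_rank: int; pp_world_size: int                      # pipeline-parallel coordinates
    pp_is_first_stage: bool; pp_is_last_stage: bool
    pp_prev_rank: int; pp_next_rank: int; pp_last_rank: int; pp_group_ids: list
    tp_first_rank: int; tp_group_ids: list                # tensor-parallel coordinates
    grid: "torch.Tensor"                                  # indexed as grid[tp, :, dp] in utils.py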
train.py (24 changed lines)

@@ -7,9 +7,9 @@ from transformers import AutoConfig, AutoModelForCausalLM

 import argparse

-import parallel_context as pc
+import process_group_manager as pgm
 from utils import set_all_seed, display_parallelism_grid
-from parallel_context import setup_parallel_context
+from process_group_manager import setup_process_group_manager
 from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from data_parallel import DataParallel
 from dataset import MicroBatchDataLoader
@@ -55,9 +55,9 @@ if __name__ == "__main__":
     dist.init_process_group(rank=local_rank, world_size=world_size, backend="nccl", init_method=f"tcp://{host}:{port}")
     torch.cuda.set_device(local_rank)
     device = torch.device("cuda", local_rank)
-    setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
+    setup_process_group_manager(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)

-    if pc.parallel_context.global_rank == local_rank:
+    if pgm.process_group_manager.global_rank == local_rank:
         display_parallelism_grid()

     set_all_seed(seed=42)
@@ -65,10 +65,10 @@ if __name__ == "__main__":
     config = AutoConfig.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device)

-    if pc.parallel_context.pp_world_size > 1:
+    if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, config).to(device)

-    if pc.parallel_context.dp_world_size > 1:
+    if pgm.process_group_manager.dp_world_size > 1:
         model = DataParallel(model, config).to(device)

     model.train()
@@ -81,8 +81,10 @@ if __name__ == "__main__":
     tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN

     dist.barrier()

     #TODO: find a way to setup reference model training
     #TODO: Add Context Parallelism
     #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
     #TODO: Add activation checkpointing
     #TODO: add gradient accumulation

@@ -91,12 +93,12 @@ if __name__ == "__main__":

         optimizer.zero_grad()

-        if pc.parallel_context.pp_world_size > 1:
+        if pgm.process_group_manager.pp_world_size > 1:
             loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
         else:
             loss = train_step(model, data_loader, device)

-        if pc.parallel_context.dp_world_size > 1:
+        if pgm.process_group_manager.dp_world_size > 1:
             # Average gradient across DP ranks
             model.all_reduce_gradients()

@@ -105,7 +107,7 @@ if __name__ == "__main__":
         step += 1

         #NOTE(fmom): change later to log on rank 0 (g00) everytime ?
-        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank:
-            print(f"[rank {pc.parallel_context.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
+        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank and pgm.process_group_manager.global_rank == pgm.process_group_manager.dp_first_rank:
+            print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")

     dist.destroy_process_group()
utils.py (26 changed lines)

@@ -1,5 +1,5 @@
 import torch, random, numpy as np
-import parallel_context as pc
+import process_group_manager as pgm

 def set_all_seed(seed):
     for module in [random, np.random]: module.seed(seed)
@@ -21,37 +21,37 @@ def display_parallelism_grid():
         return " ".join("PP".center(box_width) for _ in range(pp_size))

     output = []
-    sample_row = _create_row(pc.parallel_context.grid[0, :, 0])
+    sample_row = _create_row(pgm.process_group_manager.grid[0, :, 0])
     row_width = len(sample_row)
     border = _create_border(row_width)

     output.append(f"=== Global Parallelism Configuration ===")
-    output.append(f"DP Size: {pc.parallel_context.dp_size}, PP Size: {pc.parallel_context.pp_size}, TP Size: {pc.parallel_context.grid.shape[0]}")
+    output.append(f"DP Size: {pgm.process_group_manager.dp_size}, PP Size: {pgm.process_group_manager.pp_size}, TP Size: {pgm.process_group_manager.grid.shape[0]}")
     output.append("") # Top spacing

-    for dp in range(pc.parallel_context.dp_size):
+    for dp in range(pgm.process_group_manager.dp_size):
         output.append(f"DP {dp}:")
         output.append(f"{'':>8}{border}")

-        for tp in range(pc.parallel_context.grid.shape[0]):
+        for tp in range(pgm.process_group_manager.grid.shape[0]):
             if tp == 0:
-                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
+                output.append(f"{'TP':>7} {_create_row(pgm.process_group_manager.grid[tp, :, dp])}")
             else:
                 output.append(f"{'':8}{border}")
-                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
+                output.append(f"{'TP':>7} {_create_row(pgm.process_group_manager.grid[tp, :, dp])}")

         output.append(f"{'':8}{border}")
-        if pc.parallel_context.pp_size > 1:
-            output.append(f"{'':>7}{_create_pp_line(row_width, pc.parallel_context.pp_size)}")
+        if pgm.process_group_manager.pp_size > 1:
+            output.append(f"{'':>7}{_create_pp_line(row_width, pgm.process_group_manager.pp_size)}")

         output.append("") # Spacing between DP blocks

     output.append("") # Bottom spacing

     output.append(f"=== Local Parallelism Configuration ===")
-    output.append(pc.parallel_context.__str__())
-    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.dp_group_ids]}")
-    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.pp_group_ids]}")
-    output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.tp_group_ids]}")
+    output.append(pgm.process_group_manager.__str__())
+    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}")
+    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}")
+    output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.tp_group_ids]}")

     print("\n".join(output))