Enhance parallel context to handle 3D parallelism (tensor, pipeline, and data parallel)

This commit is contained in:
ferdinand.mom 2024-09-23 10:28:01 +00:00
parent c36d415b47
commit bce75fd508
5 changed files with 144 additions and 34 deletions

View File

@ -9,18 +9,18 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
global STEP global STEP
global VERBOSE global VERBOSE
if operation == 'recv_forward': if operation == 'recv_forward':
if pc.parallel_context.is_pipeline_first_stage: return None if pc.parallel_context.pp_is_first_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype) tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
src = pc.parallel_context.pp_prev_rank src = pc.parallel_context.pp_prev_rank
elif operation == 'send_forward': elif operation == 'send_forward':
if pc.parallel_context.is_pipeline_last_stage: return if pc.parallel_context.pp_is_last_stage: return
dest = pc.parallel_context.pp_next_rank dest = pc.parallel_context.pp_next_rank
elif operation == 'recv_backward': elif operation == 'recv_backward':
if pc.parallel_context.is_pipeline_last_stage: return None if pc.parallel_context.pp_is_last_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype) tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
src = pc.parallel_context.pp_next_rank src = pc.parallel_context.pp_next_rank
elif operation == 'send_backward': elif operation == 'send_backward':
if pc.parallel_context.is_pipeline_first_stage: return if pc.parallel_context.pp_is_first_stage: return
dest = pc.parallel_context.pp_prev_rank dest = pc.parallel_context.pp_prev_rank
is_send = operation.startswith('send') is_send = operation.startswith('send')
peer_rank = dest if is_send else src peer_rank = dest if is_send else src
@ -35,7 +35,7 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device
global STEP global STEP
global VERBOSE global VERBOSE
is_fwd = (operation == 'send_fwd_recv_bwd') is_fwd = (operation == 'send_fwd_recv_bwd')
if (is_fwd and pc.parallel_context.is_pipeline_last_stage) or (not is_fwd and pc.parallel_context.is_pipeline_first_stage): return None if (is_fwd and pc.parallel_context.pp_is_last_stage) or (not is_fwd and pc.parallel_context.pp_is_first_stage): return None
peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank
recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype) recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)]) reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])

View File

@ -1,4 +1,4 @@
#VERBOSE=0 torchrun --nproc_per_node 3 generate.py #VERBOSE=0 torchrun --nproc_per_node 3 generate.py --pp_size 3
import os import os
import argparse import argparse
import torch, torch.distributed as dist import torch, torch.distributed as dist
@ -20,12 +20,12 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:
# Preallocate memory for output logits. # Preallocate memory for output logits.
logits = None logits = None
if pc.parallel_context.is_pipeline_last_stage: if pc.parallel_context.pp_is_last_stage:
logits = torch.empty((batch_size, seq_len, int(model.config.vocab_size)), dtype=torch.float32, device=device) logits = torch.empty((batch_size, seq_len, int(model.config.vocab_size)), dtype=torch.float32, device=device)
recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32) recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32)
batch["hidden_states"] = None if pc.parallel_context.is_pipeline_first_stage else recv_buffer batch["hidden_states"] = None if pc.parallel_context.pp_is_first_stage else recv_buffer
output_tensor = model.forward(batch, device) output_tensor = model.forward(batch, device)
@ -33,7 +33,7 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:
communicate(operation="send_forward", tensor=output_tensor) communicate(operation="send_forward", tensor=output_tensor)
# Copy logits. # Copy logits.
if pc.parallel_context.is_pipeline_last_stage: if pc.parallel_context.pp_is_last_stage:
logits = output_tensor logits = output_tensor
dist.barrier() dist.barrier()
@ -42,16 +42,16 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--pp_size", type=int, default=1)
parser.add_argument("--max_tokens", type=int, default=32) parser.add_argument("--max_tokens", type=int, default=32)
args = parser.parse_args() args = parser.parse_args()
#TODO: support only PP
local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]) local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
dist.init_process_group(backend="nccl") dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank) device = torch.device("cuda", local_rank)
setup_parallel_context(local_rank, world_size) setup_parallel_context(tp_size=1, pp_size=args.pp_size, dp_size=1)
set_all_seed(seed=42) set_all_seed(seed=42)
model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device) model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
@ -60,8 +60,8 @@ if __name__ == "__main__":
# Tokenize the input # Tokenize the input
prompts = [ prompts = [
"My name is", "My name is",
"How old are you ?", # "How old are you ?",
"What is your favorite color?", # "What is your favorite color?",
] ]
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct") tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct")
@ -88,7 +88,7 @@ if __name__ == "__main__":
logits = run_one_inference_step(model, batch_prompts, device) logits = run_one_inference_step(model, batch_prompts, device)
# Sample new token # Sample new token
if pc.parallel_context.is_pipeline_last_stage: if pc.parallel_context.pp_is_last_stage:
assert logits is not None assert logits is not None
next_token = torch.argmax(logits[:, -1], dim=-1) next_token = torch.argmax(logits[:, -1], dim=-1)
tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1) tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
@ -101,7 +101,7 @@ if __name__ == "__main__":
dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank) dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank)
# Get only the new generated tokens # Get only the new generated tokens
if pc.parallel_context.is_pipeline_last_stage: if pc.parallel_context.pp_is_last_stage:
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:] tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
outputs = tokenizer.decode(tokenized_outputs) outputs = tokenizer.decode(tokenized_outputs)

View File

@ -1,16 +1,113 @@
import os
import torch
import torch.distributed as dist import torch.distributed as dist
class ParallelContext:
    """Describe where the current process sits in a TP x PP x DP process grid.

    The global ranks ``0 .. world_size - 1`` are arranged in a 3D grid indexed
    as ``grid[tp_rank, pp_rank, dp_rank]``. From that grid the constructor
    derives, for the calling process, its coordinates along each axis, the
    global ranks that share each of its groups, and the matching
    ``torch.distributed`` process groups.

    Args:
        tp_size: Number of tensor-parallel ranks.
        pp_size: Number of pipeline-parallel stages.
        dp_size: Number of data-parallel replicas.

    Raises:
        AssertionError: If ``tp_size * pp_size * dp_size`` differs from the
            world size reported by ``torch.distributed``.
    """

    def __init__(self, tp_size, pp_size, dp_size):
        self.global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        # Prefer the launcher-provided LOCAL_RANK; otherwise fall back to a
        # value derived from the global rank.
        self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size))

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dp_size
        assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"

        self.grid = self._build_grid(self.tp_size, self.pp_size, self.dp_size)
        # Find the position of the current process in the grid
        self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()

        # Global ranks that share a group with the current process
        self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
        self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
        self.dp_group_ids = self.grid[self.tp_rank, self.pp_rank, :].tolist()
        # Nested (one list per TP rank) so the layout stays readable when printed
        self.tp_pp_group_ids = self.grid[..., self.dp_rank].tolist()

        # Process group creation.
        # torch.distributed requires every process to create every group, in
        # the same order, even the groups it does not belong to. Each axis is
        # therefore enumerated in full and only the local group is kept.
        rank_lists = self._enumerate_groups(self.grid)
        self.tp_group = self._create_groups(rank_lists["tp"], self.tp_group_ids)
        self.pp_group = self._create_groups(rank_lists["pp"], self.pp_group_ids)
        self.dp_group = self._create_groups(rank_lists["dp"], self.dp_group_ids)
        self.tp_pp_group = self._create_groups(
            rank_lists["tp_pp"],
            self.grid[:, :, self.dp_rank].flatten().tolist(),
        )

        # Tensor parallelism
        self.tp_first_rank = self.tp_group_ids[0]
        self.tp_last_rank = self.tp_group_ids[-1]
        self.tp_is_first_stage = self.tp_rank == 0
        self.tp_is_last_stage = self.tp_rank == self.tp_size - 1
        self.tp_world_size = dist.get_world_size(group=self.tp_group)

        # Pipeline parallelism
        self.pp_first_rank = self.pp_group_ids[0]
        self.pp_last_rank = self.pp_group_ids[-1]
        self.pp_is_first_stage = self.pp_rank == 0
        self.pp_is_last_stage = self.pp_rank == self.pp_size - 1
        # Neighbours are global ranks; None marks the end of the pipeline.
        self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.tp_rank, self.pp_rank + 1, self.dp_rank].item())
        self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.tp_rank, self.pp_rank - 1, self.dp_rank].item())
        self.pp_world_size = dist.get_world_size(group=self.pp_group)

        # Data parallelism
        self.dp_first_rank = self.dp_group_ids[0]
        self.dp_last_rank = self.dp_group_ids[-1]
        self.dp_is_first_stage = self.dp_rank == 0
        self.dp_is_last_stage = self.dp_rank == self.dp_size - 1
        self.dp_world_size = dist.get_world_size(group=self.dp_group)

        # Tensor parallelism and pipeline parallelism
        self.tp_pp_world_size = dist.get_world_size(group=self.tp_pp_group)

    @staticmethod
    def _build_grid(tp_size, pp_size, dp_size):
        """Return the rank grid, indexed as ``grid[tp_rank, pp_rank, dp_rank]``."""
        world_size = tp_size * pp_size * dp_size
        return torch.arange(world_size).view(pp_size, dp_size, tp_size).permute(2, 0, 1)

    @staticmethod
    def _enumerate_groups(grid):
        """List the member ranks of every group along each axis of ``grid``.

        Returns:
            A dict with keys ``"tp"``, ``"pp"``, ``"dp"`` and ``"tp_pp"``, each
            mapping to a list of rank lists. The result depends only on the
            grid, so every process computes the same lists in the same order.
        """
        tp_size, pp_size, dp_size = grid.shape
        return {
            "tp": [grid[:, p, d].tolist() for p in range(pp_size) for d in range(dp_size)],
            "pp": [grid[t, :, d].tolist() for t in range(tp_size) for d in range(dp_size)],
            "dp": [grid[t, p, :].tolist() for t in range(tp_size) for p in range(pp_size)],
            "tp_pp": [grid[:, :, d].flatten().tolist() for d in range(dp_size)],
        }

    @staticmethod
    def _create_groups(rank_lists, own_ranks):
        """Create one process group per entry of ``rank_lists``.

        Returns:
            The group whose member list equals ``own_ranks``, or ``None`` if no
            entry matches.
        """
        own_group = None
        for ranks in rank_lists:
            group = dist.new_group(ranks)
            if ranks == own_ranks:
                own_group = group
        return own_group

    def __str__(self):
        return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})"

    def _render_parallelism_grid(self):
        """Build the text printed by :meth:`display_parallelism_grid`."""
        def _create_box(content):
            return f" {content:^3} "

        def _create_row(row):
            return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"

        def _create_border(width):
            return "+" + "-" * (width - 2) + "+"

        def _create_pp_line(width, pp_size):
            box_width = (width - pp_size + 1) // pp_size
            return " ".join("PP".center(box_width) for _ in range(pp_size))

        output = []
        sample_row = _create_row(self.grid[0, :, 0].tolist())
        row_width = len(sample_row)
        border = _create_border(row_width)

        output.append("=== Global Parallelism Configuration ===")
        output.append(f"DP Size: {self.dp_size}, PP Size: {self.pp_size}, TP Size: {self.grid.shape[0]}")
        output.append("")  # Top spacing

        # One boxed table per DP replica: rows are TP ranks, columns are PP stages.
        for dp in range(self.dp_size):
            output.append(f"DP {dp}:")
            output.append(f"{'':>8}{border}")

            for tp in range(self.grid.shape[0]):
                if tp > 0:
                    # Separator between consecutive TP rows
                    output.append(f"{'':8}{border}")
                output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp].tolist())}")

            output.append(f"{'':8}{border}")
            if self.pp_size > 1:
                output.append(f"{'':>7}{_create_pp_line(row_width, self.pp_size)}")

            output.append("")  # Spacing between DP blocks

        output.append("")  # Bottom spacing

        output.append("=== Local Parallelism Configuration ===")
        output.append(self.__str__())
        output.append(f"DP Group IDs: {['g{:02d}'.format(rank) for rank in self.dp_group_ids]}")
        output.append(f"PP Group IDs: {['g{:02d}'.format(rank) for rank in self.pp_group_ids]}")
        output.append(f"TP Group IDs: {['g{:02d}'.format(rank) for rank in self.tp_group_ids]}")
        output.append(f"TP-PP Group IDs: {[['g{:02d}'.format(rank) for rank in subgroup] for subgroup in self.tp_pp_group_ids]}")

        return "\n".join(output)

    def display_parallelism_grid(self):
        """Print the global rank grid and the groups of the current process."""
        print(self._render_parallelism_grid())
def setup_parallel_context(tp_size, pp_size, dp_size):
    """Create the module-level ``parallel_context`` for this process.

    Builds a ``ParallelContext`` from the given sizes and stores it in the
    module global ``parallel_context``.

    Args:
        tp_size: Number of tensor-parallel ranks.
        pp_size: Number of pipeline-parallel stages.
        dp_size: Number of data-parallel replicas.
    """
    global parallel_context
    parallel_context = ParallelContext(tp_size, pp_size, dp_size)

View File

@ -9,10 +9,10 @@ class PipelineParallel(nn.Module):
self.config = AutoConfig.from_pretrained(model_name) self.config = AutoConfig.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config) base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config)
layer_distribution = self.distribute_layers(self.config.num_hidden_layers) layer_distribution = self.distribute_layers(self.config.num_hidden_layers)
self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.is_pipeline_first_stage else nn.Identity() self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.pp_is_first_stage else nn.Identity()
self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution}) self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution})
self.norm = base_model.model.norm if pc.parallel_context.is_pipeline_last_stage else nn.Identity() self.norm = base_model.model.norm if pc.parallel_context.pp_is_last_stage else nn.Identity()
self.lm_head = base_model.lm_head if pc.parallel_context.is_pipeline_last_stage else nn.Identity() self.lm_head = base_model.lm_head if pc.parallel_context.pp_is_last_stage else nn.Identity()
del base_model del base_model
def distribute_layers(self, num_layers): def distribute_layers(self, num_layers):
@ -44,7 +44,7 @@ def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
batch["hidden_states"] = input_tensor batch["hidden_states"] = input_tensor
output_tensor = model.forward(batch, device) output_tensor = model.forward(batch, device)
communicate(operation='send_forward', tensor=output_tensor) communicate(operation='send_forward', tensor=output_tensor)
if pc.parallel_context.is_pipeline_last_stage: if pc.parallel_context.pp_is_last_stage:
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
logging_loss += output_tensor.item() logging_loss += output_tensor.item()
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
@ -67,7 +67,7 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
batch = next(iter(data_loader)) batch = next(iter(data_loader))
batch["hidden_states"] = input_tensor batch["hidden_states"] = input_tensor
output_tensor = model.forward(batch, device) output_tensor = model.forward(batch, device)
if pc.parallel_context.is_pipeline_last_stage: if pc.parallel_context.pp_is_last_stage:
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
nonlocal logging_loss nonlocal logging_loss
logging_loss += output_tensor.item() logging_loss += output_tensor.item()

View File

@ -1,10 +1,11 @@
#VERBOSE=0 torchrun --nproc_per_node 3 train.py #VERBOSE=0 torchrun --nproc_per_node 3 train.py --pp_size 3
import os import os
import torch, torch.distributed as dist import torch, torch.distributed as dist
from torch.optim import AdamW from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset from datasets import load_dataset
from transformers import AutoTokenizer from transformers import AutoTokenizer
import argparse
import parallel_context as pc import parallel_context as pc
from utils import set_all_seed from utils import set_all_seed
@ -30,14 +31,26 @@ class MicroBatchDataLoader(DataLoader):
return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None} return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tp_size", type=int, default=1)
parser.add_argument("--pp_size", type=int, default=1)
parser.add_argument("--dp_size", type=int, default=1)
args = parser.parse_args()
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]) local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800 SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800
dist.init_process_group(backend="nccl") dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank) device = torch.device("cuda", local_rank)
setup_parallel_context(local_rank, world_size) setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
if pc.parallel_context.global_rank == 0:
pc.parallel_context.display_parallelism_grid()
set_all_seed(seed=42) set_all_seed(seed=42)
model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device) model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
@ -52,5 +65,5 @@ if __name__ == "__main__":
optimizer.step() optimizer.step()
trained_tokens += tokens_per_step trained_tokens += tokens_per_step
step += 1 step += 1
if pc.parallel_context.is_pipeline_last_stage: if pc.parallel_context.pp_is_last_stage:
print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}") print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")