diff --git a/distributed_primtives.py b/distributed_primtives.py new file mode 100644 index 0000000..1678073 --- /dev/null +++ b/distributed_primtives.py @@ -0,0 +1,46 @@ +import os +import parallel_context as pc +import torch, torch.distributed as dist +import parallel_context as pc + +STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1" + +def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None): + global STEP + global VERBOSE + if operation == 'recv_forward': + if pc.parallel_context.is_pipeline_first_stage: return None + tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype) + src = pc.parallel_context.pp_prev_rank + elif operation == 'send_forward': + if pc.parallel_context.is_pipeline_last_stage: return + dest = pc.parallel_context.pp_next_rank + elif operation == 'recv_backward': + if pc.parallel_context.is_pipeline_last_stage: return None + tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype) + src = pc.parallel_context.pp_next_rank + elif operation == 'send_backward': + if pc.parallel_context.is_pipeline_first_stage: return + dest = pc.parallel_context.pp_prev_rank + is_send = operation.startswith('send') + peer_rank = dest if is_send else src + op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank) + if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pc.parallel_context.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pc.parallel_context.pp_rank}", flush=True) + [req.wait() for req in dist.batch_isend_irecv([op])] + torch.cuda.synchronize() + if VERBOSE: STEP += 1 + return tensor if not is_send else None + +def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device): + global STEP + global VERBOSE + is_fwd = (operation == 'send_fwd_recv_bwd') + if (is_fwd and pc.parallel_context.is_pipeline_last_stage) or (not is_fwd and pc.parallel_context.is_pipeline_first_stage): return None + peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank + recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype) + reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)]) + if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pc.parallel_context.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pc.parallel_context.pp_rank} | "f"STEP {STEP=} | RANK:{pc.parallel_context.pp_rank}", flush=True) + [req.wait() for req in reqs] + torch.cuda.synchronize() + if VERBOSE: STEP += 1 + return recv_tensor \ No newline at end of file diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..cd1b4ce --- /dev/null +++ b/generate.py @@ -0,0 +1,112 @@ +#VERBOSE=0 torchrun --nproc_per_node 3 generate.py +import os +import argparse +import torch, torch.distributed as dist +from transformers import AutoTokenizer + +from utils import set_all_seed +import parallel_context as pc +from parallel_context import setup_parallel_context +from pipeline_parallel import PipelineParallel +from distributed_primtives import communicate + +def run_one_inference_step(model, batch, device) -> torch.Tensor: + if pc.parallel_context.pp_world_size == 1: + return model.forward(batch, device) + + batch_size = batch["input_ids"].shape[0] + seq_len = batch["input_ids"].shape[1] + tensor_shapes = (batch_size, seq_len, model.config.hidden_size) + + # Preallocate memory for output logits. + logits = None + if pc.parallel_context.is_pipeline_last_stage: + logits = torch.empty((batch_size, seq_len, int(model.config.vocab_size)), dtype=torch.float32, device=device) + + recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32) + + batch["hidden_states"] = None if pc.parallel_context.is_pipeline_first_stage else recv_buffer + + output_tensor = model.forward(batch, device) + + # Send output to the next stage. + communicate(operation="send_forward", tensor=output_tensor) + + # Copy logits. + if pc.parallel_context.is_pipeline_last_stage: + logits = output_tensor + + dist.barrier() + + return logits + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max_tokens", type=int, default=32) + args = parser.parse_args() + + #TODO: support only PP + local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]) + + dist.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + setup_parallel_context(local_rank, world_size) + set_all_seed(seed=42) + model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device) + + model.eval() + + # Tokenize the input + prompts = [ + "My name is", + "How old are you ?", + "What is your favorite color?", + ] + + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct") + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + + tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device=device) + + + for _ in range(args.max_tokens): + + # Create the batch + seq_len = tokenized_prompts["input_ids"].shape[1] + position_index = torch.arange(seq_len).view(1, -1).to(device=device) + + batch_prompts = { + "input_ids": tokenized_prompts["input_ids"], + "target_ids": None, + "position_index": position_index, + "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool), + "hidden_states": None, + } + + logits = run_one_inference_step(model, batch_prompts, device) + + # Sample new token + if pc.parallel_context.is_pipeline_last_stage: + assert logits is not None + next_token = torch.argmax(logits[:, -1], dim=-1) + tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1) + tokenized_prompts["attention_mask"] = torch.cat([tokenized_prompts["attention_mask"], torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.int64, device=device)], dim=-1) + else: + tokenized_prompts["input_ids"] = torch.zeros((tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, device=device) + tokenized_prompts["attention_mask"] = torch.zeros((tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), dtype=torch.int64, device=device) + + dist.broadcast(tokenized_prompts["input_ids"], src=pc.parallel_context.pp_last_rank) + dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank) + + # Get only the new generated tokens + if pc.parallel_context.is_pipeline_last_stage: + for i, prompt in enumerate(prompts): + tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:] + outputs = tokenizer.decode(tokenized_outputs) + + print(f"Input: {prompt}") + print(f"Output: {outputs}") + print("------") + \ No newline at end of file diff --git a/parallel_context.py b/parallel_context.py new file mode 100644 index 0000000..fa230c9 --- /dev/null +++ b/parallel_context.py @@ -0,0 +1,16 @@ +import torch.distributed as dist + +class ParallelContext: + def __init__(self, pp_rank, pp_world_size): + self.pp_rank, self.pp_world_size = pp_rank, pp_world_size + self.pp_group = dist.new_group(list(range(self.pp_world_size))) + self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else (self.pp_rank + 1) % self.pp_world_size + self.pp_prev_rank = None if self.pp_rank == 0 else (self.pp_rank - 1) % self.pp_world_size + self.is_pipeline_last_stage = self.pp_rank == self.pp_world_size - 1 + #TODO: refactor to handle TP and DP + self.pp_last_rank = self.pp_world_size - 1 + self.is_pipeline_first_stage = self.pp_rank == 0 + +def setup_parallel_context(local_rank, world_size): + global parallel_context + parallel_context = ParallelContext(pp_rank=local_rank, pp_world_size=world_size) \ No newline at end of file diff --git a/pipeline_parallel.py b/pipeline_parallel.py new file mode 100644 index 0000000..305b1f9 --- /dev/null +++ b/pipeline_parallel.py @@ -0,0 +1,104 @@ +from transformers import AutoConfig, AutoModelForCausalLM +import parallel_context as pc +from distributed_primtives import communicate, bidirectional_communicate +import torch, torch.nn as nn, torch.nn.functional as F + +class PipelineParallel(nn.Module): + def __init__(self, model_name): + super().__init__() + self.config = AutoConfig.from_pretrained(model_name) + base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config) + layer_distribution = self.distribute_layers(self.config.num_hidden_layers) + self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.is_pipeline_first_stage else nn.Identity() + self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution}) + self.norm = base_model.model.norm if pc.parallel_context.is_pipeline_last_stage else nn.Identity() + self.lm_head = base_model.lm_head if pc.parallel_context.is_pipeline_last_stage else nn.Identity() + del base_model + + def distribute_layers(self, num_layers): + layers_per_gpu = [num_layers // pc.parallel_context.pp_world_size + (1 if i < num_layers % pc.parallel_context.pp_world_size else 0) for i in range(pc.parallel_context.pp_world_size)] + start_layer = sum(layers_per_gpu[:pc.parallel_context.pp_rank]) + return list(range(start_layer, start_layer + layers_per_gpu[pc.parallel_context.pp_rank])) + + def forward(self, batch, device): + x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device) + x = self.embed_tokens(x) + for layer in self.decoder_layers.values(): + x = layer(x, position_ids=batch["position_index"].to(device))[0] + x = self.norm(x) + return self.lm_head(x) + + def backward(self, input_tensor, output_tensor, output_tensor_grad): + if input_tensor is not None: input_tensor.retain_grad() + if output_tensor_grad is None: + output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format) + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False) + return input_tensor.grad if input_tensor is not None else None + +def pipeline_parallel_afab(model, data_loader, tensor_shapes, device): + logging_loss, input_tensors, output_tensors = 0.0, [], [] + + for _ in range(data_loader.num_local_micro_batches): # All forward passes + input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32) + batch = next(iter(data_loader)) + batch["hidden_states"] = input_tensor + output_tensor = model.forward(batch, device) + communicate(operation='send_forward', tensor=output_tensor) + if pc.parallel_context.is_pipeline_last_stage: + output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') + logging_loss += output_tensor.item() + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + + for _ in range(data_loader.num_local_micro_batches): # All backward passes + output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32) + input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) + communicate(operation='send_backward', tensor=input_tensor_grad) + + return logging_loss + +def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device): + num_warmup_microbatches = min(pc.parallel_context.pp_world_size - pc.parallel_context.pp_rank - 1, data_loader.num_local_micro_batches) + num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches + logging_loss, input_tensors, output_tensors = 0.0, [], [] + + def _forward_step(input_tensor): + batch = next(iter(data_loader)) + batch["hidden_states"] = input_tensor + output_tensor = model.forward(batch, device) + if pc.parallel_context.is_pipeline_last_stage: + output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') + nonlocal logging_loss + logging_loss += output_tensor.item() + return output_tensor + + for _ in range(num_warmup_microbatches): # Warmup forward passes + input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32) + output_tensor = _forward_step(input_tensor) + communicate(operation='send_forward', tensor=output_tensor) + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + + if num_microbatches_remaining > 0: + input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32) + + for i in range(num_microbatches_remaining): # 1F1B steady state + output_tensor = _forward_step(input_tensor) + output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, dtype=torch.float32, device=device) + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) + if i == num_microbatches_remaining - 1: # last iteration + input_tensor = None + communicate(operation='send_backward', tensor=input_tensor_grad) + else: + input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, dtype=torch.float32, device=device) + + for _ in range(num_warmup_microbatches): # Cooldown backward passes + input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32) + input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) + communicate(operation='send_backward', tensor=input_tensor_grad) + return logging_loss \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1e73be0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +torch +numpy +datasets +transformers==4.44.1 \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..88bae37 --- /dev/null +++ b/train.py @@ -0,0 +1,56 @@ +#VERBOSE=0 torchrun --nproc_per_node 3 train.py +import os +import torch, torch.distributed as dist +from torch.optim import AdamW +from torch.utils.data import DataLoader, DistributedSampler +from datasets import load_dataset +from transformers import AutoTokenizer + +import parallel_context as pc +from utils import set_all_seed +from parallel_context import setup_parallel_context +from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel + +class MicroBatchDataLoader(DataLoader): + def __init__(self, global_batch_size, micro_batch_size, data_parallel_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None): + self.global_batch_size, self.micro_batch_size, self.data_parallel_size, self.seq_length = global_batch_size, micro_batch_size, data_parallel_size, seq_length + self.local_batch_size = self.global_batch_size // self.data_parallel_size + self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size + self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + self.dataset = load_dataset(dataset_name, split=split) + if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset)))) + dist.barrier() + self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"]) + super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=DistributedSampler(self.dataset, num_replicas=data_parallel_size, rank=0, shuffle=False), shuffle=False) + + def collate_batch(self, batch_data): + batch_input_ids = torch.stack([item['input_ids'] for item in batch_data]) + batch_size, seq_len = batch_input_ids.shape + return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None} + +if __name__ == "__main__": + os.environ["TOKENIZERS_PARALLELISM"] = "false" + local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]) + + SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800 + dist.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + setup_parallel_context(local_rank, world_size) + + set_all_seed(seed=42) + model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device) + data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, 1, SEQ_LEN, "roneneldan/TinyStories", "HuggingFaceTB/SmolLM-360M-Instruct", num_samples=NUM_SAMPLES) + tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, model.config.hidden_size) + optimizer = AdamW(model.parameters(), lr=LEARNING_RATE) + trained_tokens, step = 0, 0 + tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN + while trained_tokens < MAX_TOKENS: + optimizer.zero_grad() + loss = pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device) #loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device) + optimizer.step() + trained_tokens += tokens_per_step + step += 1 + if pc.parallel_context.is_pipeline_last_stage: + print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}") diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..352ae8c --- /dev/null +++ b/utils.py @@ -0,0 +1,6 @@ +import torch, random, numpy as np + +def set_all_seed(seed): + for module in [random, np.random]: module.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)