Add training and generation for pipeline parallelism (PP)
parent 8a78ad1c82
commit c36d415b47
distributed_primtives.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import os

import torch, torch.distributed as dist

import parallel_context as pc

# Global step counter for debug logging; VERBOSE is read once from the environment.
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"


def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
    """Point-to-point transfer of activations (forward) or gradients (backward)
    between adjacent pipeline stages. Receives allocate a fresh buffer; sends
    reuse the caller's tensor. Returns the received tensor, or None for sends
    and for no-op calls at the ends of the pipeline."""
    global STEP
    if operation == 'recv_forward':
        if pc.parallel_context.is_pipeline_first_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
        src = pc.parallel_context.pp_prev_rank
    elif operation == 'send_forward':
        if pc.parallel_context.is_pipeline_last_stage: return
        dest = pc.parallel_context.pp_next_rank
    elif operation == 'recv_backward':
        if pc.parallel_context.is_pipeline_last_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
        src = pc.parallel_context.pp_next_rank
    elif operation == 'send_backward':
        if pc.parallel_context.is_pipeline_first_stage: return
        dest = pc.parallel_context.pp_prev_rank
    is_send = operation.startswith('send')
    peer_rank = dest if is_send else src
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pc.parallel_context.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pc.parallel_context.pp_rank}", flush=True)
    [req.wait() for req in dist.batch_isend_irecv([op])]
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return tensor if not is_send else None


def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device):
    """Overlap a send and a receive with the same peer, as in the 1F1B steady
    state: 'send_fwd_recv_bwd' talks to the next stage, 'send_bwd_recv_fwd'
    to the previous one."""
    global STEP
    is_fwd = (operation == 'send_fwd_recv_bwd')
    if (is_fwd and pc.parallel_context.is_pipeline_last_stage) or (not is_fwd and pc.parallel_context.is_pipeline_first_stage): return None
    peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
    if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pc.parallel_context.pp_rank} -> {peer_rank} | receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pc.parallel_context.pp_rank} | STEP:{STEP} | RANK:{pc.parallel_context.pp_rank}", flush=True)
    [req.wait() for req in reqs]
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return recv_tensor
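Usage sketch (not part of the commit): how one stage chains these primitives for a single forward hand-off. The shape values are illustrative assumptions (hidden size 960 matches SmolLM-360M), and it presumes torchrun has launched the script and dist.init_process_group plus setup_parallel_context have already run on every rank.

    import torch
    import parallel_context as pc
    from distributed_primtives import communicate

    shapes = (2, 10, 960)  # (batch, seq_len, hidden_size); assumed values
    hidden = communicate(operation='recv_forward', shapes=shapes, dtype=torch.float32)
    if hidden is None:  # first stage: produce the initial activations locally
        hidden = torch.randn(shapes, device='cuda')
    communicate(operation='send_forward', tensor=hidden)  # no-op on the last stage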
generate.py (new file, 112 lines)
@@ -0,0 +1,112 @@
#VERBOSE=0 torchrun --nproc_per_node 3 generate.py
import os
import argparse
import torch, torch.distributed as dist
from transformers import AutoTokenizer

from utils import set_all_seed
import parallel_context as pc
from parallel_context import setup_parallel_context
from pipeline_parallel import PipelineParallel
from distributed_primtives import communicate


def run_one_inference_step(model, batch, device) -> torch.Tensor:
    # Without pipeline parallelism there is nothing to communicate.
    if pc.parallel_context.pp_world_size == 1:
        return model.forward(batch, device)

    batch_size = batch["input_ids"].shape[0]
    seq_len = batch["input_ids"].shape[1]
    tensor_shapes = (batch_size, seq_len, model.config.hidden_size)

    # Preallocate memory for the output logits on the last stage.
    logits = None
    if pc.parallel_context.is_pipeline_last_stage:
        logits = torch.empty((batch_size, seq_len, int(model.config.vocab_size)), dtype=torch.float32, device=device)

    # Receive the previous stage's activations (None on the first stage).
    recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32)
    batch["hidden_states"] = None if pc.parallel_context.is_pipeline_first_stage else recv_buffer

    output_tensor = model.forward(batch, device)

    # Send the output to the next stage (no-op on the last stage).
    communicate(operation="send_forward", tensor=output_tensor)

    # On the last stage the forward output is the logits.
    if pc.parallel_context.is_pipeline_last_stage:
        logits = output_tensor

    dist.barrier()
    return logits


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_tokens", type=int, default=32)
    args = parser.parse_args()

    # TODO: only pipeline parallelism is supported for now.
    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    setup_parallel_context(local_rank, world_size)
    set_all_seed(seed=42)
    model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
    model.eval()

    # Tokenize the input prompts.
    prompts = [
        "My name is",
        "How old are you ?",
        "What is your favorite color?",
    ]

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct")
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device=device)

    for _ in range(args.max_tokens):
        # Create the batch for this decoding step.
        seq_len = tokenized_prompts["input_ids"].shape[1]
        position_index = torch.arange(seq_len).view(1, -1).to(device=device)

        batch_prompts = {
            "input_ids": tokenized_prompts["input_ids"],
            "target_ids": None,
            "position_index": position_index,
            "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool),
            "hidden_states": None,
        }

        logits = run_one_inference_step(model, batch_prompts, device)

        # Sample the next token (greedy decoding) on the last stage. The other
        # stages allocate zero-filled buffers of the right size and wait for
        # the broadcast below to fill them in.
        if pc.parallel_context.is_pipeline_last_stage:
            assert logits is not None
            next_token = torch.argmax(logits[:, -1], dim=-1)
            tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
            tokenized_prompts["attention_mask"] = torch.cat([tokenized_prompts["attention_mask"], torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.int64, device=device)], dim=-1)
        else:
            tokenized_prompts["input_ids"] = torch.zeros((tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, device=device)
            tokenized_prompts["attention_mask"] = torch.zeros((tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), dtype=torch.int64, device=device)

        dist.broadcast(tokenized_prompts["input_ids"], src=pc.parallel_context.pp_last_rank)
        dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank)

    # Print only the newly generated tokens.
    if pc.parallel_context.is_pipeline_last_stage:
        for i, prompt in enumerate(prompts):
            tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
            outputs = tokenizer.decode(tokenized_outputs)
            print(f"Input: {prompt}")
            print(f"Output: {outputs}")
            print("------")
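The loop above decodes greedily with argmax. As a sketch of an alternative (not in this commit), temperature sampling could replace that single line on the last stage; `temperature` is an assumed value, and the existing broadcasts would keep the other stages in sync unchanged:

    # Hypothetical drop-in for: next_token = torch.argmax(logits[:, -1], dim=-1)
    temperature = 0.7  # assumed hyperparameter
    probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)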
parallel_context.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import torch.distributed as dist


class ParallelContext:
    """Holds the pipeline-parallel topology for this process."""
    def __init__(self, pp_rank, pp_world_size):
        self.pp_rank, self.pp_world_size = pp_rank, pp_world_size
        self.pp_group = dist.new_group(list(range(self.pp_world_size)))
        # Neighboring stages; None at the ends of the pipeline.
        self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else (self.pp_rank + 1) % self.pp_world_size
        self.pp_prev_rank = None if self.pp_rank == 0 else (self.pp_rank - 1) % self.pp_world_size
        self.is_pipeline_last_stage = self.pp_rank == self.pp_world_size - 1
        # TODO: refactor to handle TP and DP.
        self.pp_last_rank = self.pp_world_size - 1
        self.is_pipeline_first_stage = self.pp_rank == 0


def setup_parallel_context(local_rank, world_size):
    # Module-level singleton, accessed as pc.parallel_context elsewhere.
    global parallel_context
    parallel_context = ParallelContext(pp_rank=local_rank, pp_world_size=world_size)
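For concreteness, a standalone sketch of the neighbor wiring this class computes for the three-rank launch the scripts above use. The `neighbors` helper is illustrative, not part of the commit; `dist.new_group` is omitted so it runs without a process group:

    def neighbors(pp_rank, pp_world_size):
        # Mirrors the pp_prev_rank / pp_next_rank logic in ParallelContext.
        prev_rank = None if pp_rank == 0 else pp_rank - 1
        next_rank = None if pp_rank == pp_world_size - 1 else pp_rank + 1
        return prev_rank, next_rank

    for r in range(3):
        print(r, neighbors(r, 3))
    # 0 (None, 1)   <- first stage
    # 1 (0, 2)
    # 2 (1, None)   <- last stage (pp_last_rank == 2)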
pipeline_parallel.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from transformers import AutoConfig, AutoModelForCausalLM
import parallel_context as pc
from distributed_primtives import communicate, bidirectional_communicate
import torch, torch.nn as nn, torch.nn.functional as F


class PipelineParallel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config)
        layer_distribution = self.distribute_layers(self.config.num_hidden_layers)
        # Only the first stage embeds tokens; only the last stage applies the
        # final norm and LM head. Everything else is an identity.
        self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.is_pipeline_first_stage else nn.Identity()
        self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution})
        self.norm = base_model.model.norm if pc.parallel_context.is_pipeline_last_stage else nn.Identity()
        self.lm_head = base_model.lm_head if pc.parallel_context.is_pipeline_last_stage else nn.Identity()
        del base_model

    def distribute_layers(self, num_layers):
        # Split layers as evenly as possible; the first (num_layers % world_size)
        # stages get one extra layer.
        layers_per_gpu = [num_layers // pc.parallel_context.pp_world_size + (1 if i < num_layers % pc.parallel_context.pp_world_size else 0) for i in range(pc.parallel_context.pp_world_size)]
        start_layer = sum(layers_per_gpu[:pc.parallel_context.pp_rank])
        return list(range(start_layer, start_layer + layers_per_gpu[pc.parallel_context.pp_rank]))

    def forward(self, batch, device):
        # Non-first stages consume the hidden states received from the previous stage.
        x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device)
        x = self.embed_tokens(x)
        for layer in self.decoder_layers.values():
            x = layer(x, position_ids=batch["position_index"].to(device))[0]
        x = self.norm(x)
        return self.lm_head(x)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        if input_tensor is not None: input_tensor.retain_grad()
        if output_tensor_grad is None:
            output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)
        # Gradient w.r.t. the received activations, to send to the previous stage.
        return input_tensor.grad if input_tensor is not None else None


def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
    # AFAB schedule: run all forward passes first, then all backward passes.
    logging_loss, input_tensors, output_tensors = 0.0, [], []

    for _ in range(data_loader.num_local_micro_batches):  # All forward passes
        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
        batch = next(iter(data_loader))
        batch["hidden_states"] = input_tensor
        output_tensor = model.forward(batch, device)
        communicate(operation='send_forward', tensor=output_tensor)
        if pc.parallel_context.is_pipeline_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            logging_loss += output_tensor.item()
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    for _ in range(data_loader.num_local_micro_batches):  # All backward passes
        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(operation='send_backward', tensor=input_tensor_grad)

    return logging_loss


def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
    # 1F1B schedule: warmup forwards, then a steady state alternating one
    # forward with one backward, then cooldown backwards.
    num_warmup_microbatches = min(pc.parallel_context.pp_world_size - pc.parallel_context.pp_rank - 1, data_loader.num_local_micro_batches)
    num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
    logging_loss, input_tensors, output_tensors = 0.0, [], []

    def _forward_step(input_tensor):
        batch = next(iter(data_loader))
        batch["hidden_states"] = input_tensor
        output_tensor = model.forward(batch, device)
        if pc.parallel_context.is_pipeline_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            nonlocal logging_loss
            logging_loss += output_tensor.item()
        return output_tensor

    for _ in range(num_warmup_microbatches):  # Warmup forward passes
        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
        output_tensor = _forward_step(input_tensor)
        communicate(operation='send_forward', tensor=output_tensor)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    if num_microbatches_remaining > 0:
        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)

    for i in range(num_microbatches_remaining):  # 1F1B steady state
        output_tensor = _forward_step(input_tensor)
        output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        if i == num_microbatches_remaining - 1:  # Last iteration: nothing left to receive.
            input_tensor = None
            communicate(operation='send_backward', tensor=input_tensor_grad)
        else:
            input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)

    for _ in range(num_warmup_microbatches):  # Cooldown backward passes
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(operation='send_backward', tensor=input_tensor_grad)
    return logging_loss
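A worked example of the 1F1B warmup arithmetic above, for the 3-stage launch these scripts use. The `schedule` helper is illustrative, not part of the commit; it just mirrors the num_warmup_microbatches formula:

    def schedule(pp_world_size, num_micro_batches):
        # Mirrors: num_warmup = min(world_size - rank - 1, num_micro_batches)
        for rank in range(pp_world_size):
            warmup = min(pp_world_size - rank - 1, num_micro_batches)
            print(f"rank {rank}: warmup={warmup}, steady={num_micro_batches - warmup}")

    schedule(3, 3)
    # rank 0: warmup=2, steady=1
    # rank 1: warmup=1, steady=2
    # rank 2: warmup=0, steady=3   <- the last stage starts backward immediately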
requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
torch
numpy
datasets
transformers==4.44.1
train.py (new file, 56 lines)
@@ -0,0 +1,56 @@
#VERBOSE=0 torchrun --nproc_per_node 3 train.py
import os
import torch, torch.distributed as dist
from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
from transformers import AutoTokenizer

import parallel_context as pc
from utils import set_all_seed
from parallel_context import setup_parallel_context
from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel


class MicroBatchDataLoader(DataLoader):
    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
        self.global_batch_size, self.micro_batch_size, self.data_parallel_size, self.seq_length = global_batch_size, micro_batch_size, data_parallel_size, seq_length
        self.local_batch_size = self.global_batch_size // self.data_parallel_size
        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = load_dataset(dataset_name, split=split)
        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
        dist.barrier()
        # Tokenize to seq_length + 1 so inputs and shifted targets both span seq_length tokens.
        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
        # rank=0 so every pipeline stage iterates the same micro-batches (no DP yet).
        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=DistributedSampler(self.dataset, num_replicas=data_parallel_size, rank=0, shuffle=False), shuffle=False)

    def collate_batch(self, batch_data):
        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
        batch_size, seq_len = batch_input_ids.shape
        # Sequence-first layout: input_ids/target_ids are (seq_len - 1, batch_size).
        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len - 1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len - 1, seq_len - 1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}


if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])

    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    setup_parallel_context(local_rank, world_size)

    set_all_seed(seed=42)
    model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, 1, SEQ_LEN, "roneneldan/TinyStories", "HuggingFaceTB/SmolLM-360M-Instruct", num_samples=NUM_SAMPLES)
    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, model.config.hidden_size)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    trained_tokens, step = 0, 0
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
    while trained_tokens < MAX_TOKENS:
        optimizer.zero_grad()
        loss = pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device)  # or: pipeline_parallel_afab(model, data_loader, tensor_shapes, device)
        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1
        if pc.parallel_context.is_pipeline_last_stage:
            print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
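A quick sanity check of the training-loop arithmetic implied by the constants above (illustrative, not part of the commit):

    # GLOBAL_BATCH_SIZE=6, MICRO_BATCH_SIZE=2, dp_size=1, SEQ_LEN=10, MAX_TOKENS=1800
    num_global_micro_batches = 6 // 2   # = 3 micro-batches per step
    tokens_per_step = 3 * 2 * 10        # = 60 tokens per optimizer step
    total_steps = 1800 // 60            # = 30 training steps before the loop exits
    print(num_global_micro_batches, tokens_per_step, total_steps)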