diff --git a/distributed/distributed_primtives.py b/distributed/distributed_primtives.py
index df77c31..4451083 100644
--- a/distributed/distributed_primtives.py
+++ b/distributed/distributed_primtives.py
@@ -1,7 +1,7 @@
 import os
+from typing import List, Optional
 import distributed.process_group_manager as pgm
 import torch, torch.distributed as dist
-import distributed.process_group_manager as pgm

 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
@@ -44,6 +44,53 @@ def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, devi
     torch.cuda.synchronize()
     if VERBOSE: STEP += 1
     return recv_tensor
+
+class ContextComms:
+    def __init__(self, msg: str = ""):
+        global STEP
+        global VERBOSE
+        self._pending_operations: List[dist.P2POp] = []
+        self._active_requests = None
+        self.rank = pgm.process_group_manager.cp_rank
+        self.world_size = pgm.process_group_manager.cp_world_size
+        self.send_rank = pgm.process_group_manager.cp_send_rank
+        self.recv_rank = pgm.process_group_manager.cp_recv_rank
+        if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True)
+
+    def send_recv(self, tensor_to_send, recv_tensor=None):
+        if recv_tensor is None:
+            result_tensor = torch.zeros_like(tensor_to_send)
+        else:
+            result_tensor = recv_tensor
+
+        send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group)
+        recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group)
+
+        self._pending_operations.extend([send_operation, recv_operation])
+
+        if VERBOSE:
+            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True)
+            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True)
+        return result_tensor
+
+    def commit(self):
+        if self._active_requests is not None: raise RuntimeError("Commit called twice")
+        self._active_requests = dist.batch_isend_irecv(self._pending_operations)
+        if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True)
+
+    def wait(self):
+        if self._active_requests is None: raise RuntimeError("Wait called before commit")
+        for i, request in enumerate(self._active_requests):
+            request.wait()
+            if VERBOSE:
+                operation_type = "send" if i % 2 == 0 else "receive"
+                peer_rank = self.send_rank if operation_type == "send" else self.recv_rank
+                print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True)
+        torch.cuda.synchronize()
+        self._active_requests = None
+        self._pending_operations = []
+        if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)
+
 def all_reduce_loss_across_dp_ranks(loss, device):
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
     # Reduce the loss across all workers so that every rank has the updated loss value.
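The ContextComms helper added above is only exercised later by RingAttentionFunc; as orientation, here is a minimal sketch (not part of the diff) of its intended send_recv → commit → wait cycle, assuming torch.distributed and pgm.process_group_manager are already initialized with cp_world_size > 1. The function name ring_pass_kv is made up for illustration.

# Illustrative sketch only: rotate K/V blocks once around the context-parallel
# ring using the ContextComms API added above.
import torch
from distributed.distributed_primtives import ContextComms

def ring_pass_kv(k: torch.Tensor, v: torch.Tensor):
    comm = ContextComms("kv_demo")
    for step in range(comm.world_size - 1):
        next_k = comm.send_recv(k)   # queue isend to cp_send_rank / irecv from cp_recv_rank
        next_v = comm.send_recv(v)
        comm.commit()                # launch all queued ops via dist.batch_isend_irecv
        # ... attention on the current (k, v) block would happen here ...
        comm.wait()                  # block until this step's exchange completes
        k, v = next_k, next_v        # adopt the neighbour's block for the next step
    return k, v
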
diff --git a/distributed/process_group_manager.py b/distributed/process_group_manager.py
index 738df8b..1cfff02 100644
--- a/distributed/process_group_manager.py
+++ b/distributed/process_group_manager.py
@@ -42,7 +42,9 @@ class ProcessGroupManager:
         self.cp_first_rank = self.cp_group_ids[0]
         self.cp_last_rank = self.cp_group_ids[-1]
         self.cp_world_size = dist.get_world_size(group=self.cp_group)
-
+        self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_size]
+        self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_size]
+
         # Pipeline parallelism
         self.pp_first_rank = self.pp_group_ids[0]
         self.pp_last_rank = self.pp_group_ids[-1]
diff --git a/model.py b/model.py
index 8cad4e9..fc54862 100644
--- a/model.py
+++ b/model.py
@@ -70,9 +70,9 @@ class RotaryEmbedding(nn.Module):
         k_embed = (k * cos) + (self._rotate_half(k) * sin)
         return q_embed, k_embed

-class MultiHeadAttention(nn.Module):
+class Attention(nn.Module):
     def __init__(self, config, is_causal):
-        super(MultiHeadAttention, self).__init__()
+        super(Attention, self).__init__()

         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -90,10 +90,10 @@ class MultiHeadAttention(nn.Module):
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.rotary = RotaryEmbedding(dim=self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta)

-    def forward(self, x, position_ids):
-        q = self.q_proj(x)
-        k = self.k_proj(x)
-        v = self.v_proj(x)
+    def forward(self, input_ids, position_ids):
+        q = self.q_proj(input_ids)
+        k = self.k_proj(input_ids)
+        v = self.v_proj(input_ids)

         batch, seq_len, _ = q.shape

@@ -130,7 +130,7 @@ class DecoderLayer(nn.Module):
     def __init__(self, config, is_causal):
         super(DecoderLayer, self).__init__()
-        self.attention = MultiHeadAttention(config, is_causal)
+        self.attention = Attention(config, is_causal)
         self.norm_attn = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = MLP(config)
         self.norm_mlp = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
diff --git a/parallel/context_parallel.py b/parallel/context_parallel.py
index b4b5da8..6c1bd11 100644
--- a/parallel/context_parallel.py
+++ b/parallel/context_parallel.py
@@ -1,15 +1,258 @@
-import torch.distributed as dist
+# Inspired by https://github.com/zhuzilin/ring-flash-attention
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from torch import distributed as dist
+from typing import Any, Optional, Tuple
+from distributed.distributed_primtives import ContextComms
+from model import Attention
 import distributed.process_group_manager as pgm
+import lovely_tensors as lt; lt.monkey_patch()
+from utils import print

 class ContextParallel(nn.Module):
     def __init__(self, model, config):
         super().__init__()
         self.model = model
+        self.config = config
+        for name, module in self.model.named_modules():
+            if isinstance(module, Attention) and not isinstance(module, RingAttention):
+                parent_name, child_name = name.rsplit('.', 1)
+                parent_module = self.model.get_submodule(parent_name)
+                setattr(parent_module, child_name, RingAttention(module))
+                del module
+
     def forward(self, *args, **kwargs):
         return self.model(*args, **kwargs)

     def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
\ No newline at end of file
+        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
+
+class RingAttention(nn.Module):
+    def __init__(self, original_mha):
+        super().__init__()
+
+        self.hidden_size = original_mha.hidden_size
+        self.num_heads = original_mha.num_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = original_mha.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.is_causal = original_mha.is_causal
+
+        # Copy the weights from the original Attention
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+        self.q_proj.weight.data.copy_(original_mha.q_proj.weight.data)
+        self.k_proj.weight.data.copy_(original_mha.k_proj.weight.data)
+        self.v_proj.weight.data.copy_(original_mha.v_proj.weight.data)
+        self.o_proj.weight.data.copy_(original_mha.o_proj.weight.data)
+
+        self.rotary = original_mha.rotary
+
+    def forward(self, input_ids, position_ids):
+        batch_size, seq_len, _ = input_ids.shape
+
+        q = self.q_proj(input_ids)
+        k = self.k_proj(input_ids)
+        v = self.v_proj(input_ids)
+
+        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if self.rotary is not None:
+            cos, sin = self.rotary(v, position_ids)
+            q, k = self.rotary.apply_rotary_pos_emb(q, k, cos, sin)
+
+        k = self._repeat_kv(k, self.num_key_value_groups)
+        v = self._repeat_kv(v, self.num_key_value_groups)
+
+        # Apply ring attention
+        sm_scale = 1.0 / (q.size(-1) ** 0.5)
+        output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
+
+        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
+        output = self.o_proj(output)
+        return output
+
+    def _repeat_kv(self, x, n_rep):
+        batch, num_key_value_heads, seq_len, head_dim = x.shape
+        if n_rep == 1:
+            return x
+        x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
+        return x.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
+
+class RingAttentionFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, q, k, v, sm_scale, is_causal):
+        comm = ContextComms("comm")
+        # NOTE: Find a better way to save these tensors without cloning
+        k_og = k.clone()
+        v_og = v.clone()
+        out, lse = None, None
+        next_k, next_v = None, None
+
+        for step in range(comm.world_size):
+            if step + 1 != comm.world_size:
+                next_k = comm.send_recv(k)
+                next_v = comm.send_recv(v)
+                comm.commit()
+
+            if not is_causal or step <= comm.rank:
+                block_out, block_lse = ring_attention_forward(
+                    q, k, v, sm_scale, is_causal and step == 0
+                )
+                out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+
+            if step + 1 != comm.world_size:
+                comm.wait()
+                k = next_k
+                v = next_v
+
+        out = out.to(q.dtype)
+        ctx.save_for_backward(q, k_og, v_og, out, lse.squeeze(-1))
+        ctx.sm_scale = sm_scale
+        ctx.is_causal = is_causal
+        return out
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+
+        q, k, v, out, softmax_lse = ctx.saved_tensors
+        sm_scale = ctx.sm_scale
+        is_causal = ctx.is_causal
+
+        kv_comm = ContextComms("kv_comm")
+        d_kv_comm = ContextComms("d_kv_comm")
+        dq, dk, dv = None, None, None
+        next_dk, next_dv = None, None
+
+        block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+        block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+        block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
+
+        next_dk, next_dv = None, None
+        next_k, next_v = None, None
+
+        for step in range(kv_comm.world_size):
+            if step + 1 != kv_comm.world_size:
+                next_k = kv_comm.send_recv(k)
+                next_v = kv_comm.send_recv(v)
+                kv_comm.commit()
+
+            if step <= kv_comm.rank or not is_causal:
+                bwd_causal = is_causal and step == 0
+
+                block_dq_buffer, block_dk_buffer, block_dv_buffer = ring_attention_backward(
+                    dout, q, k, v, out, softmax_lse, sm_scale, bwd_causal
+                )
+
+                if dq is None:
+                    dq = block_dq_buffer.to(torch.float32)
+                    dk = block_dk_buffer.to(torch.float32)
+                    dv = block_dv_buffer.to(torch.float32)
+                else:
+                    dq += block_dq_buffer
+                    d_kv_comm.wait()
+                    dk = block_dk_buffer + next_dk
+                    dv = block_dv_buffer + next_dv
+            elif step != 0:
+                d_kv_comm.wait()
+                dk = next_dk
+                dv = next_dv
+
+            if step + 1 != kv_comm.world_size:
+                kv_comm.wait()
+                k = next_k
+                v = next_v
+
+            next_dk = d_kv_comm.send_recv(dk)
+            next_dv = d_kv_comm.send_recv(dv)
+            d_kv_comm.commit()
+
+        d_kv_comm.wait()
+
+        return dq, next_dk, next_dv, None, None
+
+def ring_attention_forward(q, k, v, sm_scale, is_causal):
+    batch_size, nheads, seqlen, d = q.shape
+    S = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
+
+    if is_causal:
+        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=q.device, dtype=torch.bool), diagonal=1)
+        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, nheads, seqlen, seqlen)
+        S.masked_fill_(causal_mask, float('-inf'))
+
+    # Online softmax
+    S_max = torch.max(S, dim=-1, keepdim=True)[0]
+    exp_S = torch.exp(S - S_max)
+    exp_sum = torch.sum(exp_S, dim=-1, keepdim=True)
+    log_sum_exp = torch.log(exp_sum) + S_max
+    P = exp_S / exp_sum
+    O = torch.matmul(P, v)
+    return O, log_sum_exp.squeeze(-1)
+
+def ring_attention_backward(dO, Q, K, V, O, softmax_lse, sm_scale, is_causal):
+    batch_size, nheads, seqlen, d = Q.shape
+
+    # Recreate S and P from log_sum_exp
+    S = torch.matmul(Q, K.transpose(-2, -1)) * sm_scale
+    if is_causal:
+        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=Q.device, dtype=torch.bool), diagonal=1)
+        S = S.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), float('-inf'))
+
+    P = torch.exp(S - softmax_lse.unsqueeze(-1))
+    # Step 1: Compute dV
+    dV = torch.matmul(P.transpose(-2, -1), dO)
+    # Step 2: Compute dP
+    dP = torch.matmul(dO, V.transpose(-2, -1))
+    # Step 3: Compute D
+    D = torch.sum(dO * O, dim=-1, keepdim=True)
+    # Step 4: Compute dS
+    dS = P * (dP - D)
+    # Apply causal mask to dS if is_causal is True
+    if is_causal:
+        dS = dS.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), 0)
+    # Step 5: Compute dQ
+    dQ = torch.matmul(dS, K) * sm_scale
+    # Step 6: Compute dK
+    dK = torch.matmul(dS.transpose(-2, -1), Q) * sm_scale
+    return dQ, dK, dV
+
+def update_out_and_lse(
+    out: Optional[torch.Tensor],
+    lse: Optional[torch.Tensor],
+    block_out: torch.Tensor,
+    block_lse: torch.Tensor,
+    slice_: Optional[Any] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    def _update(current_out, current_lse):
+        # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+        # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
+        # For additional context and discussion, please refer to:
+        # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
+        current_out = current_out - F.sigmoid(block_lse - current_lse) * (current_out - block_out)
+        current_lse = current_lse - F.logsigmoid(current_lse - block_lse)
+        return current_out, current_lse
+
+    block_out = block_out.to(torch.float32)
+    block_lse = block_lse.unsqueeze(dim=-1)
+
+    if out is None:
+        if slice_ is not None:
+            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
+        return block_out, block_lse
+
+    if slice_ is not None:
+        out[slice_], lse[slice_] = _update(out[slice_], lse[slice_])
+    else:
+        out, lse = _update(out, lse)
+
+    return out, lse
\ No newline at end of file
diff --git a/train.py b/train.py
index 30456e4..5ec3d24 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,7 @@ import argparse

 import distributed.process_group_manager as pgm
 from distributed.distributed_primtives import all_reduce_gradients_across_dp_cp_ranks
-from utils import set_all_seed, print, display_4D_parallelism_grid
+from utils import set_all_seed, print
 from distributed.process_group_manager import setup_process_group_manager
 from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from parallel.data_parallel import DataParallel
@@ -26,6 +26,8 @@ class MicroBatchDataLoader(DataLoader):
         self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
         self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size

+        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
+
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         self.dataset = load_dataset(dataset_name, split=split)
         if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
@@ -42,8 +44,21 @@ class MicroBatchDataLoader(DataLoader):
     def collate_batch(self, batch_data):
         batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
         batch_size, seq_len = batch_input_ids.shape
-        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
-
+        start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu
+        end_idx = start_idx + self.seq_length_per_gpu
+        input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
+        target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
+        position_index = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
+        local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
+        attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()
+
+        return {
+            "input_ids": input_ids,
+            "target_ids": target_ids,
+            "position_index": position_index,
+            "attn_mask": attn_mask,
+            "hidden_states": None
+        }

 def train_step(model, data_loader, device):
     total_loss = 0.0
@@ -55,11 +70,11 @@
         position_ids = batch["position_index"].to(device)
         target_ids = batch["target_ids"].to(device)

-        outputs = model(input_ids=input_ids, position_ids=position_ids)
-        logits = outputs.logits
+        batch_size, seq_len = input_ids.shape

-        # Use your suggested cross_entropy calculation
-        loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean')
+        outputs = model(input_ids=input_ids, position_ids=position_ids)
+
+        loss = F.cross_entropy(outputs.view(batch_size * seq_len, -1), target_ids.view(-1), reduction="mean")

         loss.backward()

@@ -91,6 +106,8 @@ if __name__ == "__main__":

     SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42

+    assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
+
     backend = "gloo" if args.use_cpu else "nccl"

     if backend == "nccl":
@@ -140,6 +157,9 @@
     model.load_state_dict(torch.load("smollm.pth"))

+    # if pgm.process_group_manager.tp_world_size > 1:
+    #     model = TensorParallel(model, config).to(device)
+
     if pgm.process_group_manager.cp_size > 1:
         model = ContextParallel(model, config).to(device)

@@ -149,13 +169,10 @@
     if pgm.process_group_manager.dp_world_size > 1:
         model = DataParallel(model, config).to(device)

-    # if pgm.process_group_manager.tp_world_size > 1:
-    #     model = TensorParallel(model, config).to(device)
-
     model.train()

     data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
-    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, config.hidden_size)
+    tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)

     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

     trained_tokens, step = 0, 0
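For reference, the sigmoid/logsigmoid update used by update_out_and_lse in parallel/context_parallel.py is an exact streaming log-sum-exp merge: combining two per-block attention results this way reproduces attention over the concatenated key/value blocks. A small single-process sanity check (shapes and names chosen only for illustration, not part of the diff):

# Illustrative sketch only: check the block-merge rule against full attention.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 2, 3, 4, 8                      # batch, heads, queries per rank, head dim
scale = D ** -0.5
q = torch.randn(B, H, S, D)
k1, v1 = torch.randn(B, H, S, D), torch.randn(B, H, S, D)   # "local" KV block
k2, v2 = torch.randn(B, H, S, D), torch.randn(B, H, S, D)   # "remote" KV block

def block_attn(q, k, v):
    s = q @ k.transpose(-2, -1) * scale
    lse = torch.logsumexp(s, dim=-1, keepdim=True)           # (B, H, S, 1)
    return torch.exp(s - lse) @ v, lse

out1, lse1 = block_attn(q, k1, v1)
out2, lse2 = block_attn(q, k2, v2)

# Same update rule as _update() in update_out_and_lse
out = out1 - torch.sigmoid(lse2 - lse1) * (out1 - out2)
lse = lse1 - F.logsigmoid(lse1 - lse2)

# Reference: ordinary attention over the concatenated KV blocks
s_cat = q @ torch.cat([k1, k2], dim=-2).transpose(-2, -1) * scale
ref_out = torch.softmax(s_cat, dim=-1) @ torch.cat([v1, v2], dim=-2)
ref_lse = torch.logsumexp(s_cat, dim=-1, keepdim=True)
print(torch.allclose(out, ref_out, atol=1e-5), torch.allclose(lse, ref_lse, atol=1e-5))  # True True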