From f6c9a39d17837cd5cf0b3a401bc635fbdf38c465 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Wed, 30 Oct 2024 15:23:29 +0000
Subject: [PATCH] fix splitting input twice for context parallel (done in dataloader)

---
 model.py                         |  4 ++--
 src/parallel/context_parallel.py | 14 +-------------
 train.py                         |  4 +---
 3 files changed, 4 insertions(+), 18 deletions(-)

diff --git a/model.py b/model.py
index 934e768..e7bf4ee 100644
--- a/model.py
+++ b/model.py
@@ -2,7 +2,7 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from src.parallel.context_parallel import ring_attention, update_rope
+from src.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 from flash_attn.ops.triton.layer_norm import layer_norm_fn
@@ -161,7 +161,7 @@ class DecoderLayer(nn.Module):
         self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim]
 
         # For context parallelism, we split the input. We need to get the correct cos and sin for each split
-        self.cos, self.sin = update_rope(self.cos, self.sin)
+        self.cos, self.sin = update_rope_for_context_parallel(self.cos, self.sin)
 
     def forward(self, x, attention_mask = None, position_ids = None):
         #TODO: Use the default position_ids for RoPE during training. If we have time, work on generation
diff --git a/src/parallel/context_parallel.py b/src/parallel/context_parallel.py
index 035e8a0..10a5b89 100644
--- a/src/parallel/context_parallel.py
+++ b/src/parallel/context_parallel.py
@@ -181,19 +181,7 @@ def update_out_and_lse(
 
     return out, lse
 
-def parallel_input(input_ids, target_ids):
-    cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
-    batch_size, seq_length = input_ids.size()
-    assert seq_length % cp_word_size == 0, f"Input sequence length ({seq_length}) must be divisible by cp_world_size ({cp_word_size})"
-    size_per_partition = seq_length // cp_word_size
-
-    # Calculate start and end indices for this rank
-    start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
-    local_input_ids = input_ids[:, start_idx:end_idx]
-    local_target_ids = target_ids[:, start_idx:end_idx]
-    return local_input_ids, local_target_ids
-
-def update_rope(cos, sin):
+def update_rope_for_context_parallel(cos, sin):
     seq_len, _ = cos.size()
     cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
     assert seq_len % cp_word_size == 0, f"Input sequence length ({seq_len}) must be divisible by cp_world_size ({cp_word_size})"
diff --git a/train.py b/train.py
index ac8d588..263a86a 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,6 @@ import os
 import json
 import time
 import argparse
-from src.parallel.context_parallel import parallel_input
 import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
@@ -38,7 +37,6 @@ def train_step(model, data_loader, device):
         batch = next(data_loader)
         input_ids = batch["input_ids"].to(device)
         target_ids = batch["target_ids"].to(device)
-        input_ids, target_ids = parallel_input(input_ids, target_ids) # for context parallel, we need to split the input
 
         # disable gradient synchronization for all but the last micro-batch
         if requires_grad_sync:
@@ -68,7 +66,7 @@ if __name__ == "__main__":
 
     os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
     os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
-    os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.float16!
+    os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.bfloat16!
     os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"
 
     dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 # if GPU is not available or not supported, use torch.float32