fix spliting input twice for context parallel (done in dataloader)

This commit is contained in:
ferdinand.mom 2024-10-30 15:23:29 +00:00
parent 363dbd5c05
commit f6c9a39d17
3 changed files with 4 additions and 18 deletions

View File

@ -2,7 +2,7 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.parallel.context_parallel import ring_attention, update_rope
from src.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.layers.rotary import apply_rotary_emb
from flash_attn.ops.triton.layer_norm import layer_norm_fn
@ -161,7 +161,7 @@ class DecoderLayer(nn.Module):
self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim]
# For context parallelism, we split the input. We need to get the correct cos and sin for each split
self.cos, self.sin = update_rope(self.cos, self.sin)
self.cos, self.sin = update_rope_for_context_parallel(self.cos, self.sin)
def forward(self, x, attention_mask = None, position_ids = None):
#TODO: Use the default position_ids for RoPE during training. If we have time, work on generation

View File

@ -181,19 +181,7 @@ def update_out_and_lse(
return out, lse
def parallel_input(input_ids, target_ids):
cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
batch_size, seq_length = input_ids.size()
assert seq_length % cp_word_size == 0, f"Input sequence length ({seq_length}) must be divisible by cp_world_size ({cp_word_size})"
size_per_partition = seq_length // cp_word_size
# Calculate start and end indices for this rank
start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
local_input_ids = input_ids[:, start_idx:end_idx]
local_target_ids = target_ids[:, start_idx:end_idx]
return local_input_ids, local_target_ids
def update_rope(cos, sin):
def update_rope_for_context_parallel(cos, sin):
seq_len, _ = cos.size()
cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
assert seq_len % cp_word_size == 0, f"Input sequence length ({seq_len}) must be divisible by cp_world_size ({cp_word_size})"

View File

@ -13,7 +13,6 @@ import os
import json
import time
import argparse
from src.parallel.context_parallel import parallel_input
import torch.nn.functional as F
import torch, torch.distributed as dist
from torch.optim import AdamW
@ -38,7 +37,6 @@ def train_step(model, data_loader, device):
batch = next(data_loader)
input_ids = batch["input_ids"].to(device)
target_ids = batch["target_ids"].to(device)
input_ids, target_ids = parallel_input(input_ids, target_ids) # for context parallel, we need to split the input
# disable gradient synchronization for all but the last micro-batch
if requires_grad_sync:
@ -68,7 +66,7 @@ if __name__ == "__main__":
os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.float16!
os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.bfloat16!
os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 # if GPU is not available or not supported, use torch.float32