fix spliting input twice for context parallel (done in dataloader)
This commit is contained in:
parent
363dbd5c05
commit
f6c9a39d17
4
model.py
4
model.py
@ -2,7 +2,7 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from src.parallel.context_parallel import ring_attention, update_rope
|
||||
from src.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from flash_attn.layers.rotary import apply_rotary_emb
|
||||
from flash_attn.ops.triton.layer_norm import layer_norm_fn
|
||||
@ -161,7 +161,7 @@ class DecoderLayer(nn.Module):
|
||||
self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim]
|
||||
|
||||
# For context parallelism, we split the input. We need to get the correct cos and sin for each split
|
||||
self.cos, self.sin = update_rope(self.cos, self.sin)
|
||||
self.cos, self.sin = update_rope_for_context_parallel(self.cos, self.sin)
|
||||
|
||||
def forward(self, x, attention_mask = None, position_ids = None):
|
||||
#TODO: Use the default position_ids for RoPE during training. If we have time, work on generation
|
||||
|
||||
@ -181,19 +181,7 @@ def update_out_and_lse(
|
||||
|
||||
return out, lse
|
||||
|
||||
def parallel_input(input_ids, target_ids):
|
||||
cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
|
||||
batch_size, seq_length = input_ids.size()
|
||||
assert seq_length % cp_word_size == 0, f"Input sequence length ({seq_length}) must be divisible by cp_world_size ({cp_word_size})"
|
||||
size_per_partition = seq_length // cp_word_size
|
||||
|
||||
# Calculate start and end indices for this rank
|
||||
start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
|
||||
local_input_ids = input_ids[:, start_idx:end_idx]
|
||||
local_target_ids = target_ids[:, start_idx:end_idx]
|
||||
return local_input_ids, local_target_ids
|
||||
|
||||
def update_rope(cos, sin):
|
||||
def update_rope_for_context_parallel(cos, sin):
|
||||
seq_len, _ = cos.size()
|
||||
cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
|
||||
assert seq_len % cp_word_size == 0, f"Input sequence length ({seq_len}) must be divisible by cp_world_size ({cp_word_size})"
|
||||
|
||||
4
train.py
4
train.py
@ -13,7 +13,6 @@ import os
|
||||
import json
|
||||
import time
|
||||
import argparse
|
||||
from src.parallel.context_parallel import parallel_input
|
||||
import torch.nn.functional as F
|
||||
import torch, torch.distributed as dist
|
||||
from torch.optim import AdamW
|
||||
@ -38,7 +37,6 @@ def train_step(model, data_loader, device):
|
||||
batch = next(data_loader)
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
target_ids = batch["target_ids"].to(device)
|
||||
input_ids, target_ids = parallel_input(input_ids, target_ids) # for context parallel, we need to split the input
|
||||
|
||||
# disable gradient synchronization for all but the last micro-batch
|
||||
if requires_grad_sync:
|
||||
@ -68,7 +66,7 @@ if __name__ == "__main__":
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
|
||||
os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.float16!
|
||||
os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.bfloat16!
|
||||
os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"
|
||||
|
||||
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 # if GPU is not available or not supported, use torch.float32
|
||||
|
||||
Loading…
Reference in New Issue
Block a user