fix splitting input twice for context parallel (done in dataloader)
parent 363dbd5c05
commit f6c9a39d17

model.py (4 changed lines)
@@ -2,7 +2,7 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from src.parallel.context_parallel import ring_attention, update_rope
+from src.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 from flash_attn.ops.triton.layer_norm import layer_norm_fn
@@ -161,7 +161,7 @@ class DecoderLayer(nn.Module):
         self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim]
 
         # For context parallelism, we split the input. We need to get the correct cos and sin for each split
-        self.cos, self.sin = update_rope(self.cos, self.sin)
+        self.cos, self.sin = update_rope_for_context_parallel(self.cos, self.sin)
 
     def forward(self, x, attention_mask = None, position_ids = None):
         #TODO: Use the default position_ids for RoPE during training. If we have time, work on generation

src/parallel/context_parallel.py

@@ -181,19 +181,7 @@ def update_out_and_lse(
 
     return out, lse
 
-def parallel_input(input_ids, target_ids):
-    cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
-    batch_size, seq_length = input_ids.size()
-    assert seq_length % cp_word_size == 0, f"Input sequence length ({seq_length}) must be divisible by cp_world_size ({cp_word_size})"
-    size_per_partition = seq_length // cp_word_size
-
-    # Calculate start and end indices for this rank
-    start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
-    local_input_ids = input_ids[:, start_idx:end_idx]
-    local_target_ids = target_ids[:, start_idx:end_idx]
-    return local_input_ids, local_target_ids
-
-def update_rope(cos, sin):
+def update_rope_for_context_parallel(cos, sin):
     seq_len, _ = cos.size()
     cp_rank, cp_word_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
     assert seq_len % cp_word_size == 0, f"Input sequence length ({seq_len}) must be divisible by cp_world_size ({cp_word_size})"
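
The diff only shows the first lines of the renamed helper. A minimal sketch of what update_rope_for_context_parallel presumably does next, assuming it slices the precomputed cos/sin tables the same way the deleted parallel_input sliced input_ids (the slicing is an assumption, not the verified implementation; pgm is the process-group manager module already used in this file):

    def update_rope_for_context_parallel(cos, sin):
        # Keep only the rotary cos/sin rows for the sequence chunk owned by this
        # context-parallel rank (assumed behaviour, mirroring parallel_input).
        seq_len, _ = cos.size()
        cp_rank, cp_world_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
        assert seq_len % cp_world_size == 0, f"Sequence length ({seq_len}) must be divisible by cp_world_size ({cp_world_size})"
        size_per_partition = seq_len // cp_world_size
        start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
        return cos[start_idx:end_idx], sin[start_idx:end_idx]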

train.py (4 changed lines)
@@ -13,7 +13,6 @@ import os
 import json
 import time
 import argparse
-from src.parallel.context_parallel import parallel_input
 import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW

@@ -38,7 +37,6 @@ def train_step(model, data_loader, device):
         batch = next(data_loader)
         input_ids = batch["input_ids"].to(device)
         target_ids = batch["target_ids"].to(device)
-        input_ids, target_ids = parallel_input(input_ids, target_ids) # for context parallel, we need to split the input
 
         # disable gradient synchronization for all but the last micro-batch
         if requires_grad_sync:

@@ -68,7 +66,7 @@ if __name__ == "__main__":
 
     os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
    os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
-    os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.float16!
+    os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.bfloat16!
     os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"
 
     dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 # if GPU is not available or not supported, use torch.float32
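
With the parallel_input call removed from train_step, the per-rank split presumably happens once inside the dataloader, as the commit message says. A rough sketch of such a dataloader-side split, reusing the logic of the deleted parallel_input (the function name and how it is wired into the dataloader are illustrative, not the repo's actual API):

    def split_batch_for_context_parallel(batch):
        # Assumed dataloader-side equivalent of the removed parallel_input:
        # each context-parallel rank keeps only its contiguous sequence chunk.
        cp_rank, cp_world_size = pgm.process_group_manager.cp_rank, pgm.process_group_manager.cp_world_size
        seq_length = batch["input_ids"].size(1)
        assert seq_length % cp_world_size == 0, f"Sequence length ({seq_length}) must be divisible by cp_world_size ({cp_world_size})"
        size_per_partition = seq_length // cp_world_size
        start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
        batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx]
        batch["target_ids"] = batch["target_ids"][:, start_idx:end_idx]
        return batch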