folder refactoring + split cp & pp communications so they live near their implementations

ferdinand.mom 2024-11-04 16:10:47 +00:00
parent 1a000975be
commit db926026a6
16 changed files with 133 additions and 132 deletions

View File

@@ -1,11 +1,10 @@
# Inspired by https://github.com/zhuzilin/ring-flash-attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from typing import Any, Optional, Tuple
from picotron.distributed.distributed_primtives import ContextComms
import picotron.process_group_manager as pgm
from picotron.context_parallel.cp_communications import ContextCommunicate
def ring_attention(q, k, v, sm_scale, is_causal):
return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
@@ -14,7 +13,7 @@ class RingAttentionFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale, is_causal):
comm = ContextComms("comm")
comm = ContextCommunicate("comm")
#TODO(fmom): add flash attention
#TODO(fmom): Find a better way to save these tensors without cloning
k_og = k.clone()
@@ -52,8 +51,8 @@ class RingAttentionFunc(torch.autograd.Function):
sm_scale = ctx.sm_scale
is_causal = ctx.is_causal
kv_comm = ContextComms("kv_comm")
d_kv_comm = ContextComms("d_kv_comm")
kv_comm = ContextCommunicate("kv_comm")
d_kv_comm = ContextCommunicate("d_kv_comm")
dq, dk, dv = None, None, None
next_dk, next_dv = None, None
@@ -187,4 +186,4 @@ def update_rope_for_context_parallel(cos, sin):
assert seq_len % cp_word_size == 0, f"Input sequence length ({seq_len}) must be divisible by cp_world_size ({cp_word_size})"
size_per_partition = seq_len // cp_word_size
start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
return cos[start_idx:end_idx], sin[start_idx:end_idx]
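To make the RoPE slicing above concrete, here is a hedged, self-contained example of the indexing that update_rope_for_context_parallel performs; the dummy tensors and the hard-coded rank/world size stand in for the pgm lookups and are not part of the diff.
import torch

# Illustrative values only: picture cp_world_size = 4 with this process at cp_rank = 2.
seq_len, head_dim, cp_world_size, cp_rank = 8, 64, 4, 2
cos = torch.randn(seq_len, head_dim)
sin = torch.randn(seq_len, head_dim)

size_per_partition = seq_len // cp_world_size               # 2 positions per rank
start_idx = cp_rank * size_per_partition                    # 4
end_idx = (cp_rank + 1) * size_per_partition                # 6
cos_local, sin_local = cos[start_idx:end_idx], sin[start_idx:end_idx]
# Each rank keeps only the rotary angles for the positions of its local sequence
# shard, matching how the input tokens are split for context parallelism.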

View File

@@ -0,0 +1,54 @@
import os
import torch
from torch import distributed as dist
from typing import List
import picotron.process_group_manager as pgm
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
class ContextCommunicate:
def __init__(self, msg: str = ""):
global STEP
global VERBOSE
self._pending_operations: List[dist.P2POp] = []
self._active_requests = None
self.rank = pgm.process_group_manager.cp_rank
self.world_size = pgm.process_group_manager.cp_world_size
self.send_rank = pgm.process_group_manager.cp_send_rank
self.recv_rank = pgm.process_group_manager.cp_recv_rank
if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True)
def send_recv(self, tensor_to_send, recv_tensor=None):
if recv_tensor is None:
result_tensor = torch.zeros_like(tensor_to_send)
else:
result_tensor = recv_tensor
send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group)
recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group)
self._pending_operations.extend([send_operation, recv_operation])
if VERBOSE:
print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True)
print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True)
return result_tensor
def commit(self):
if self._active_requests is not None: raise RuntimeError("Commit called twice")
self._active_requests = dist.batch_isend_irecv(self._pending_operations)
if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True)
def wait(self):
if self._active_requests is None: raise RuntimeError("Wait called before commit")
for i, request in enumerate(self._active_requests):
request.wait()
if VERBOSE:
operation_type = "send" if i % 2 == 0 else "receive"
peer_rank = self.send_rank if operation_type == "send" else self.recv_rank
print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True)
torch.cuda.synchronize()
self._active_requests = None
self._pending_operations = []
if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)

View File

@@ -1,9 +1,9 @@
import contextlib
import torch
import torch.distributed as dist
import contextlib
from torch import nn
from torch.autograd import Variable
from picotron.parallel.data_parallel.bucket import BucketManager
from picotron.data_parallel.bucket import BucketManager
import picotron.process_group_manager as pgm
class DataParallel(nn.Module):

View File

@@ -1,107 +0,0 @@
import os
import picotron.process_group_manager as pgm
from typing import List, Optional
import torch, torch.distributed as dist
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
global STEP
global VERBOSE
if operation == 'recv_forward':
if pgm.process_group_manager.pp_is_first_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
src = pgm.process_group_manager.pp_prev_rank
elif operation == 'send_forward':
if pgm.process_group_manager.pp_is_last_stage: return
dest = pgm.process_group_manager.pp_next_rank
elif operation == 'recv_backward':
if pgm.process_group_manager.pp_is_last_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
src = pgm.process_group_manager.pp_next_rank
elif operation == 'send_backward':
if pgm.process_group_manager.pp_is_first_stage: return
dest = pgm.process_group_manager.pp_prev_rank
is_send = operation.startswith('send')
peer_rank = dest if is_send else src
op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'' if is_send else ''} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
[req.wait() for req in dist.batch_isend_irecv([op])]
torch.cuda.synchronize()
if VERBOSE: STEP += 1
return tensor if not is_send else None
def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
global STEP
global VERBOSE
is_fwd = (operation == 'send_fwd_recv_bwd')
if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None
peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
[req.wait() for req in reqs]
torch.cuda.synchronize()
if VERBOSE: STEP += 1
return recv_tensor
class ContextComms:
def __init__(self, msg: str = ""):
global STEP
global VERBOSE
self._pending_operations: List[dist.P2POp] = []
self._active_requests = None
self.rank = pgm.process_group_manager.cp_rank
self.world_size = pgm.process_group_manager.cp_world_size
self.send_rank = pgm.process_group_manager.cp_send_rank
self.recv_rank = pgm.process_group_manager.cp_recv_rank
if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True)
def send_recv(self, tensor_to_send, recv_tensor=None):
if recv_tensor is None:
result_tensor = torch.zeros_like(tensor_to_send)
else:
result_tensor = recv_tensor
send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group)
recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group)
self._pending_operations.extend([send_operation, recv_operation])
if VERBOSE:
print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True)
print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True)
return result_tensor
def commit(self):
if self._active_requests is not None: raise RuntimeError("Commit called twice")
self._active_requests = dist.batch_isend_irecv(self._pending_operations)
if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True)
def wait(self):
if self._active_requests is None: raise RuntimeError("Wait called before commit")
for i, request in enumerate(self._active_requests):
request.wait()
if VERBOSE:
operation_type = "send" if i % 2 == 0 else "receive"
peer_rank = self.send_rank if operation_type == "send" else self.recv_rank
print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True)
torch.cuda.synchronize()
self._active_requests = None
self._pending_operations = []
if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)
def all_reduce_loss_across_dp_cp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
# only the last stage of the pipeline parallelism contains the loss
# we need to average the loss among the data/context parallel group
if pgm.process_group_manager.pp_is_last_stage:
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
return reduced_loss.item()
def all_reduce_gradients_across_dp_cp_ranks(model):
for param in model.parameters():
if param.grad is not None:
# Average the gradients across all DP & CP ranks
param.grad /= pgm.process_group_manager.cp_dp_world_size
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
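One design note on all_reduce_gradients_across_dp_cp_ranks above: pre-dividing each gradient by cp_dp_world_size and then SUM-reducing is arithmetically the same as a single averaging all-reduce. A hedged sketch of the equivalence, assuming the cp/dp group is already initialized:
# Illustrative only: both variants leave the gradient holding the cross-rank mean.
world_size = pgm.process_group_manager.cp_dp_world_size
group = pgm.process_group_manager.cp_dp_group

param.grad /= world_size                                          # scale locally...
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)    # ...then sum across ranks

# equivalent single call (ReduceOp.AVG requires a backend that supports it, e.g. NCCL):
# dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=group)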

View File

@@ -2,7 +2,7 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from picotron.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
from picotron.context_parallel import context_parallel
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.layers.rotary import apply_rotary_emb
from flash_attn.ops.triton.layer_norm import layer_norm_fn
@@ -122,7 +122,7 @@ class Attention(nn.Module):
if pgm.process_group_manager.cp_world_size > 1:
# Ring attention for context parallelism
sm_scale = 1.0 / (q.size(-1) ** 0.5)
out = ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
out = context_parallel.ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
elif os.getenv('FLASH_ATTEN', '1') == '1':
# flash attention, this is faster!
out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim]
@@ -161,7 +161,7 @@ class DecoderLayer(nn.Module):
self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim]
# For context parallelism, we split the input. We need to get the correct cos and sin for each split
self.cos, self.sin = update_rope_for_context_parallel(self.cos, self.sin)
self.cos, self.sin = context_parallel.update_rope_for_context_parallel(self.cos, self.sin)
def forward(self, x, attention_mask = None, position_ids = None):
#TODO: Use the default position_ids for RoPE during training. If we have time, work on generation

View File

@@ -1,7 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import picotron.process_group_manager as pgm
from picotron.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
import torch, torch.nn as nn, torch.nn.functional as F
import os
from picotron.pipeline_parallel.pp_communications import pipeline_communicate, bidirectional_pipeline_communicate
class PipelineParallel(nn.Module):
def __init__(self, model, config):

View File

@@ -0,0 +1,46 @@
import os
import torch
import torch.distributed as dist
import picotron.process_group_manager as pgm
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
global STEP
global VERBOSE
if operation == 'recv_forward':
if pgm.process_group_manager.pp_is_first_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
src = pgm.process_group_manager.pp_prev_rank
elif operation == 'send_forward':
if pgm.process_group_manager.pp_is_last_stage: return
dest = pgm.process_group_manager.pp_next_rank
elif operation == 'recv_backward':
if pgm.process_group_manager.pp_is_last_stage: return None
tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
src = pgm.process_group_manager.pp_next_rank
elif operation == 'send_backward':
if pgm.process_group_manager.pp_is_first_stage: return
dest = pgm.process_group_manager.pp_prev_rank
is_send = operation.startswith('send')
peer_rank = dest if is_send else src
op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'' if is_send else ''} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
[req.wait() for req in dist.batch_isend_irecv([op])]
torch.cuda.synchronize()
if VERBOSE: STEP += 1
return tensor if not is_send else None
def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
global STEP
global VERBOSE
is_fwd = (operation == 'send_fwd_recv_bwd')
if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None
peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
[req.wait() for req in reqs]
torch.cuda.synchronize()
if VERBOSE: STEP += 1
return recv_tensor
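For reference, the pipeline schedules consume these helpers roughly as receive, compute, send per microbatch. The sketch below is a minimal, hypothetical pattern for a middle stage; the activation shape and the stage_forward call are stand-ins, not the real schedule code in pipeline_parallel.
# Hedged sketch of one microbatch on a middle pipeline stage.
shapes = (micro_batch_size, seq_len, hidden_size)    # assumed activation shape
input_tensor = pipeline_communicate('recv_forward', device, dtype, shapes=shapes)
output_tensor = stage_forward(input_tensor)          # stand-in for this stage's forward pass
pipeline_communicate('send_forward', device, dtype, tensor=output_tensor)

# ...and the matching backward direction:
output_grad = pipeline_communicate('recv_backward', device, dtype, shapes=shapes)
# run autograd on output_tensor with output_grad here, producing input_tensor.grad
pipeline_communicate('send_backward', device, dtype, tensor=input_tensor.grad)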

View File

@@ -2,7 +2,7 @@
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.parallel.tensor_parallel.utils import VocabUtility
from picotron.tensor_parallel.utils import VocabUtility
import torch
import math
import torch.nn.init as init
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import picotron.process_group_manager as pgm
from picotron.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
from picotron.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
def initialize_weight_tensor(weight, vocab_embedding=False):
"""

View File

@@ -2,7 +2,7 @@
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.parallel.tensor_parallel.utils import split_tensor_along_last_dim
from picotron.tensor_parallel.utils import split_tensor_along_last_dim
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm

View File

@@ -1,5 +1,5 @@
from functools import partial
from picotron.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
from picotron.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
import torch.nn.init as init
import torch.nn as nn

View File

@@ -20,16 +20,23 @@ import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig
import numpy as np
from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel
from picotron.tensor_parallel.tensor_parallel import TensorParallel
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel
from model import Llama
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel_bucket import DataParallel
from picotron.model import Llama
import wandb
from picotron.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks
def all_reduce_loss_across_dp_cp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
# only the last stage of the pipeline parallelism contains the loss
# we need to average the loss among the data/context parallel group
if pgm.process_group_manager.pp_is_last_stage:
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
return reduced_loss.item()
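Usage note on the relocated helper: every rank calls it, but only ranks on the last pipeline stage perform the cp/dp all-reduce; other ranks simply get back their local placeholder value. A hypothetical call site, with the surrounding loop and logging purely illustrative rather than code from this diff:
# Hedged sketch of how the helper is typically invoked once per step.
loss = train_step(model, data_loader, device)
loss = all_reduce_loss_across_dp_cp_ranks(loss, device)
if pgm.process_group_manager.pp_is_last_stage:
    print(f"step loss: {loss:.4f}")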
def train_step(model, data_loader, device):
acc_loss = 0.0