From db926026a64b7689e9b1f88fa2c91d19dcce2933 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Nov 2024 16:10:47 +0000 Subject: [PATCH] folder refactoring + split cp & pp communications near implementation details --- .../context_parallel.py | 13 +-- .../context_parallel/cp_communications.py | 54 +++++++++ .../{parallel => }/data_parallel/__init__.py | 0 .../{parallel => }/data_parallel/bucket.py | 0 .../data_parallel/data_parallel.py | 0 .../data_parallel/data_parallel_bucket.py | 6 +- picotron/distributed/distributed_primtives.py | 107 ------------------ picotron/model.py | 6 +- .../pipeline_parallel.py | 8 +- .../pipeline_parallel/pp_communications.py | 46 ++++++++ .../tensor_parallel/__init__.py | 0 .../{parallel => }/tensor_parallel/layers.py | 4 +- .../tensor_parallel/mappings.py | 2 +- .../tensor_parallel/tensor_parallel.py | 2 +- .../{parallel => }/tensor_parallel/utils.py | 0 train.py | 17 ++- 16 files changed, 133 insertions(+), 132 deletions(-) rename picotron/{parallel => context_parallel}/context_parallel.py (95%) create mode 100644 picotron/context_parallel/cp_communications.py rename picotron/{parallel => }/data_parallel/__init__.py (100%) rename picotron/{parallel => }/data_parallel/bucket.py (100%) rename picotron/{parallel => }/data_parallel/data_parallel.py (100%) rename picotron/{parallel => }/data_parallel/data_parallel_bucket.py (98%) delete mode 100644 picotron/distributed/distributed_primtives.py rename picotron/{parallel => pipeline_parallel}/pipeline_parallel.py (97%) create mode 100644 picotron/pipeline_parallel/pp_communications.py rename picotron/{parallel => }/tensor_parallel/__init__.py (100%) rename picotron/{parallel => }/tensor_parallel/layers.py (97%) rename picotron/{parallel => }/tensor_parallel/mappings.py (97%) rename picotron/{parallel => }/tensor_parallel/tensor_parallel.py (93%) rename picotron/{parallel => }/tensor_parallel/utils.py (100%) diff --git a/picotron/parallel/context_parallel.py b/picotron/context_parallel/context_parallel.py similarity index 95% rename from picotron/parallel/context_parallel.py rename to picotron/context_parallel/context_parallel.py index 712bd18..b922c99 100644 --- a/picotron/parallel/context_parallel.py +++ b/picotron/context_parallel/context_parallel.py @@ -1,11 +1,10 @@ # Inspired by https://github.com/zhuzilin/ring-flash-attention import torch -import torch.nn as nn import torch.nn.functional as F -from torch import distributed as dist from typing import Any, Optional, Tuple -from picotron.distributed.distributed_primtives import ContextComms + import picotron.process_group_manager as pgm +from picotron.context_parallel.cp_communications import ContextCommunicate def ring_attention(q, k, v, sm_scale, is_causal): return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal) @@ -14,7 +13,7 @@ class RingAttentionFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, sm_scale, is_causal): - comm = ContextComms("comm") + comm = ContextCommunicate("comm") #TODO(fmom): add flash attention #TODO(fmom): Find a better to save these tensors without cloning k_og = k.clone() @@ -52,8 +51,8 @@ class RingAttentionFunc(torch.autograd.Function): sm_scale = ctx.sm_scale is_causal = ctx.is_causal - kv_comm = ContextComms("kv_comm") - d_kv_comm = ContextComms("d_kv_comm") + kv_comm = ContextCommunicate("kv_comm") + d_kv_comm = ContextCommunicate("d_kv_comm") dq, dk, dv = None, None, None next_dk, next_dv = None, None @@ -187,4 +186,4 @@ def update_rope_for_context_parallel(cos, sin): assert seq_len % 
cp_word_size == 0, f"Input sequence length ({seq_len}) must be divisible by cp_world_size ({cp_word_size})" size_per_partition = seq_len // cp_word_size start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition - return cos[start_idx:end_idx], sin[start_idx:end_idx] \ No newline at end of file + return cos[start_idx:end_idx], sin[start_idx:end_idx] diff --git a/picotron/context_parallel/cp_communications.py b/picotron/context_parallel/cp_communications.py new file mode 100644 index 0000000..5b3c0ff --- /dev/null +++ b/picotron/context_parallel/cp_communications.py @@ -0,0 +1,54 @@ +import os +import torch +from torch import distributed as dist +from typing import List + +import picotron.process_group_manager as pgm + +STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1" + +class ContextCommunicate: + def __init__(self, msg: str = ""): + global STEP + global VERBOSE + self._pending_operations: List[dist.P2POp] = [] + self._active_requests = None + self.rank = pgm.process_group_manager.cp_rank + self.world_size = pgm.process_group_manager.cp_world_size + self.send_rank = pgm.process_group_manager.cp_send_rank + self.recv_rank = pgm.process_group_manager.cp_recv_rank + if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True) + + def send_recv(self, tensor_to_send, recv_tensor=None): + if recv_tensor is None: + result_tensor = torch.zeros_like(tensor_to_send) + else: + result_tensor = recv_tensor + + send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group) + recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group) + + self._pending_operations.extend([send_operation, recv_operation]) + + if VERBOSE: + print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True) + print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True) + return result_tensor + + def commit(self): + if self._active_requests is not None: raise RuntimeError("Commit called twice") + self._active_requests = dist.batch_isend_irecv(self._pending_operations) + if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True) + + def wait(self): + if self._active_requests is None: raise RuntimeError("Wait called before commit") + for i, request in enumerate(self._active_requests): + request.wait() + if VERBOSE: + operation_type = "send" if i % 2 == 0 else "receive" + peer_rank = self.send_rank if operation_type == "send" else self.recv_rank + print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True) + torch.cuda.synchronize() + self._active_requests = None + self._pending_operations = [] + if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True) \ No newline at end of file diff --git a/picotron/parallel/data_parallel/__init__.py b/picotron/data_parallel/__init__.py similarity index 100% rename from picotron/parallel/data_parallel/__init__.py rename to picotron/data_parallel/__init__.py diff --git a/picotron/parallel/data_parallel/bucket.py 
b/picotron/data_parallel/bucket.py similarity index 100% rename from picotron/parallel/data_parallel/bucket.py rename to picotron/data_parallel/bucket.py diff --git a/picotron/parallel/data_parallel/data_parallel.py b/picotron/data_parallel/data_parallel.py similarity index 100% rename from picotron/parallel/data_parallel/data_parallel.py rename to picotron/data_parallel/data_parallel.py diff --git a/picotron/parallel/data_parallel/data_parallel_bucket.py b/picotron/data_parallel/data_parallel_bucket.py similarity index 98% rename from picotron/parallel/data_parallel/data_parallel_bucket.py rename to picotron/data_parallel/data_parallel_bucket.py index 90956f3..8f117b3 100644 --- a/picotron/parallel/data_parallel/data_parallel_bucket.py +++ b/picotron/data_parallel/data_parallel_bucket.py @@ -1,9 +1,9 @@ -import contextlib import torch -import torch.distributed as dist +import contextlib from torch import nn from torch.autograd import Variable -from picotron.parallel.data_parallel.bucket import BucketManager + +from picotron.data_parallel.bucket import BucketManager import picotron.process_group_manager as pgm class DataParallel(nn.Module): diff --git a/picotron/distributed/distributed_primtives.py b/picotron/distributed/distributed_primtives.py deleted file mode 100644 index ee64d29..0000000 --- a/picotron/distributed/distributed_primtives.py +++ /dev/null @@ -1,107 +0,0 @@ -import os -import picotron.process_group_manager as pgm -from typing import List, Optional -import torch, torch.distributed as dist - -STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1" - -def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None): - global STEP - global VERBOSE - if operation == 'recv_forward': - if pgm.process_group_manager.pp_is_first_stage: return None - tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype) - src = pgm.process_group_manager.pp_prev_rank - elif operation == 'send_forward': - if pgm.process_group_manager.pp_is_last_stage: return - dest = pgm.process_group_manager.pp_next_rank - elif operation == 'recv_backward': - if pgm.process_group_manager.pp_is_last_stage: return None - tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype) - src = pgm.process_group_manager.pp_next_rank - elif operation == 'send_backward': - if pgm.process_group_manager.pp_is_first_stage: return - dest = pgm.process_group_manager.pp_prev_rank - is_send = operation.startswith('send') - peer_rank = dest if is_send else src - op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank) - if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True) - [req.wait() for req in dist.batch_isend_irecv([op])] - torch.cuda.synchronize() - if VERBOSE: STEP += 1 - return tensor if not is_send else None - -def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype): - global STEP - global VERBOSE - is_fwd = (operation == 'send_fwd_recv_bwd') - if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None - peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank - recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype) - reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, 
send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)]) - if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True) - [req.wait() for req in reqs] - torch.cuda.synchronize() - if VERBOSE: STEP += 1 - return recv_tensor - -class ContextComms: - def __init__(self, msg: str = ""): - global STEP - global VERBOSE - self._pending_operations: List[dist.P2POp] = [] - self._active_requests = None - self.rank = pgm.process_group_manager.cp_rank - self.world_size = pgm.process_group_manager.cp_world_size - self.send_rank = pgm.process_group_manager.cp_send_rank - self.recv_rank = pgm.process_group_manager.cp_recv_rank - if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True) - - def send_recv(self, tensor_to_send, recv_tensor=None): - if recv_tensor is None: - result_tensor = torch.zeros_like(tensor_to_send) - else: - result_tensor = recv_tensor - - send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group) - recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group) - - self._pending_operations.extend([send_operation, recv_operation]) - - if VERBOSE: - print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True) - print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True) - return result_tensor - - def commit(self): - if self._active_requests is not None: raise RuntimeError("Commit called twice") - self._active_requests = dist.batch_isend_irecv(self._pending_operations) - if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True) - - def wait(self): - if self._active_requests is None: raise RuntimeError("Wait called before commit") - for i, request in enumerate(self._active_requests): - request.wait() - if VERBOSE: - operation_type = "send" if i % 2 == 0 else "receive" - peer_rank = self.send_rank if operation_type == "send" else self.recv_rank - print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True) - torch.cuda.synchronize() - self._active_requests = None - self._pending_operations = [] - if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True) - -def all_reduce_loss_across_dp_cp_ranks(loss, device): - reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) - # only the last stage of the pipeline parallelism contains the loss - # we need to average the loss among the data/context parallel group - if pgm.process_group_manager.pp_is_last_stage: - dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group) - return reduced_loss.item() - -def all_reduce_gradients_across_dp_cp_ranks(model): - for param in model.parameters(): - if param.grad is not None: - # Average the gradients across all DP & CP 
ranks - param.grad /= pgm.process_group_manager.cp_dp_world_size - dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group) \ No newline at end of file diff --git a/picotron/model.py b/picotron/model.py index a65cb08..143cf0b 100644 --- a/picotron/model.py +++ b/picotron/model.py @@ -2,7 +2,7 @@ import os import torch import torch.nn as nn import torch.nn.functional as F -from picotron.parallel.context_parallel import ring_attention, update_rope_for_context_parallel +from picotron.context_parallel import context_parallel from flash_attn.flash_attn_interface import flash_attn_func from flash_attn.layers.rotary import apply_rotary_emb from flash_attn.ops.triton.layer_norm import layer_norm_fn @@ -122,7 +122,7 @@ class Attention(nn.Module): if pgm.process_group_manager.cp_world_size > 1: # Ring attention for context parallelism sm_scale = 1.0 / (q.size(-1) ** 0.5) - out = ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim] + out = context_parallel.ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim] elif os.getenv('FLASH_ATTEN', '1') == '1': # flash attention, this is faster! out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] @@ -161,7 +161,7 @@ class DecoderLayer(nn.Module): self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] # For context parallelism, we split the input. We need to get the correct cos and sin for each split - self.cos, self.sin = update_rope_for_context_parallel(self.cos, self.sin) + self.cos, self.sin = context_parallel.update_rope_for_context_parallel(self.cos, self.sin) def forward(self, x, attention_mask = None, position_ids = None): #TODO: Use the default position_ids for RoPE during training. 
If we have time, work on generation diff --git a/picotron/parallel/pipeline_parallel.py b/picotron/pipeline_parallel/pipeline_parallel.py similarity index 97% rename from picotron/parallel/pipeline_parallel.py rename to picotron/pipeline_parallel/pipeline_parallel.py index f22bfd7..3f3e4eb 100644 --- a/picotron/parallel/pipeline_parallel.py +++ b/picotron/pipeline_parallel/pipeline_parallel.py @@ -1,7 +1,9 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + import picotron.process_group_manager as pgm -from picotron.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate -import torch, torch.nn as nn, torch.nn.functional as F -import os +from picotron.pipeline_parallel.pp_communications import pipeline_communicate, bidirectional_pipeline_communicate class PipelineParallel(nn.Module): def __init__(self, model, config): diff --git a/picotron/pipeline_parallel/pp_communications.py b/picotron/pipeline_parallel/pp_communications.py new file mode 100644 index 0000000..fcbd31b --- /dev/null +++ b/picotron/pipeline_parallel/pp_communications.py @@ -0,0 +1,46 @@ +import os +import torch +import torch.distributed as dist +import picotron.process_group_manager as pgm + +STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1" + +def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None): + global STEP + global VERBOSE + if operation == 'recv_forward': + if pgm.process_group_manager.pp_is_first_stage: return None + tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype) + src = pgm.process_group_manager.pp_prev_rank + elif operation == 'send_forward': + if pgm.process_group_manager.pp_is_last_stage: return + dest = pgm.process_group_manager.pp_next_rank + elif operation == 'recv_backward': + if pgm.process_group_manager.pp_is_last_stage: return None + tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype) + src = pgm.process_group_manager.pp_next_rank + elif operation == 'send_backward': + if pgm.process_group_manager.pp_is_first_stage: return + dest = pgm.process_group_manager.pp_prev_rank + is_send = operation.startswith('send') + peer_rank = dest if is_send else src + op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank) + if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True) + [req.wait() for req in dist.batch_isend_irecv([op])] + torch.cuda.synchronize() + if VERBOSE: STEP += 1 + return tensor if not is_send else None + +def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype): + global STEP + global VERBOSE + is_fwd = (operation == 'send_fwd_recv_bwd') + if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None + peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank + recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype) + reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)]) + if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | 
"f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True) + [req.wait() for req in reqs] + torch.cuda.synchronize() + if VERBOSE: STEP += 1 + return recv_tensor \ No newline at end of file diff --git a/picotron/parallel/tensor_parallel/__init__.py b/picotron/tensor_parallel/__init__.py similarity index 100% rename from picotron/parallel/tensor_parallel/__init__.py rename to picotron/tensor_parallel/__init__.py diff --git a/picotron/parallel/tensor_parallel/layers.py b/picotron/tensor_parallel/layers.py similarity index 97% rename from picotron/parallel/tensor_parallel/layers.py rename to picotron/tensor_parallel/layers.py index 14a42ef..9dd562d 100644 --- a/picotron/parallel/tensor_parallel/layers.py +++ b/picotron/tensor_parallel/layers.py @@ -2,7 +2,7 @@ Inspired by Fair Scale/Megatron's Tensor Parallelism implementation Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale """ -from picotron.parallel.tensor_parallel.utils import VocabUtility +from picotron.tensor_parallel.utils import VocabUtility import torch import math import torch.nn.init as init @@ -10,7 +10,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter from typing import Callable, Optional import picotron.process_group_manager as pgm -from picotron.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region +from picotron.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region def initialize_weight_tensor(weight, vocab_embedding=False): """ diff --git a/picotron/parallel/tensor_parallel/mappings.py b/picotron/tensor_parallel/mappings.py similarity index 97% rename from picotron/parallel/tensor_parallel/mappings.py rename to picotron/tensor_parallel/mappings.py index fa18b7b..90cb356 100644 --- a/picotron/parallel/tensor_parallel/mappings.py +++ b/picotron/tensor_parallel/mappings.py @@ -2,7 +2,7 @@ Inspired by Fair Scale/Megatron's Tensor Parallelism implementation Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale """ -from picotron.parallel.tensor_parallel.utils import split_tensor_along_last_dim +from picotron.tensor_parallel.utils import split_tensor_along_last_dim import torch.distributed as dist import torch import picotron.process_group_manager as pgm diff --git a/picotron/parallel/tensor_parallel/tensor_parallel.py b/picotron/tensor_parallel/tensor_parallel.py similarity index 93% rename from picotron/parallel/tensor_parallel/tensor_parallel.py rename to picotron/tensor_parallel/tensor_parallel.py index 113e664..2f11c19 100644 --- a/picotron/parallel/tensor_parallel/tensor_parallel.py +++ b/picotron/tensor_parallel/tensor_parallel.py @@ -1,5 +1,5 @@ from functools import partial -from picotron.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor +from picotron.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor import torch.nn.init as init import torch.nn as nn diff --git a/picotron/parallel/tensor_parallel/utils.py b/picotron/tensor_parallel/utils.py similarity index 100% rename from picotron/parallel/tensor_parallel/utils.py rename to picotron/tensor_parallel/utils.py diff --git a/train.py b/train.py index 96cd14d..7af0667 100644 --- a/train.py +++ b/train.py @@ -20,16 +20,23 @@ import torch, torch.distributed as dist from torch.optim import 
AdamW from transformers import AutoConfig import numpy as np -from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel +from picotron.tensor_parallel.tensor_parallel import TensorParallel import picotron.process_group_manager as pgm from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint from picotron.data import MicroBatchDataLoader from picotron.process_group_manager import setup_process_group_manager -from picotron.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel -from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel -from model import Llama +from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel +from picotron.data_parallel.data_parallel_bucket import DataParallel +from picotron.model import Llama import wandb -from picotron.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks + +def all_reduce_loss_across_dp_cp_ranks(loss, device): + reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) + # only the last stage of the pipeline parallelism contains the loss + # we need to average the loss among the data/context parallel group + if pgm.process_group_manager.pp_is_last_stage: + dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group) + return reduced_loss.item() def train_step(model, data_loader, device): acc_loss = 0.0
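
Taken together, the renames above flatten picotron/parallel/* into top-level packages and split the old distributed_primtives module by concern. A before/after sketch of the import migration, using only paths that appear in the rename list and hunks above (a working picotron installation is assumed):

# Old layout (pre-patch)                          New layout (this patch)
# picotron.parallel.context_parallel          ->  picotron.context_parallel.context_parallel
# picotron.parallel.pipeline_parallel         ->  picotron.pipeline_parallel.pipeline_parallel
# picotron.parallel.data_parallel.*           ->  picotron.data_parallel.*
# picotron.parallel.tensor_parallel.*         ->  picotron.tensor_parallel.*
# picotron.distributed.distributed_primtives  ->  picotron.context_parallel.cp_communications
#                                                 + picotron.pipeline_parallel.pp_communications
#                                                 (loss all-reduce inlined into train.py)

# Imports as train.py and model.py spell them after this patch:
from picotron.tensor_parallel.tensor_parallel import TensorParallel
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.pipeline_parallel.pp_communications import pipeline_communicate, bidirectional_pipeline_communicate
from picotron.data_parallel.data_parallel_bucket import DataParallel
from picotron.context_parallel import context_parallel
from picotron.context_parallel.cp_communications import ContextCommunicate
from picotron.model import Llama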
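
ContextCommunicate keeps the exact send_recv / commit / wait protocol of the old ContextComms: each send_recv call enqueues a paired isend/irecv against the fixed cp_send_rank/cp_recv_rank neighbours, commit launches them as one batch_isend_irecv, and wait blocks (plus a CUDA synchronize) before the received buffers are read. A minimal sketch of one full ring rotation of key/value shards, assuming the picotron process groups are already initialized (e.g. by train.py under torchrun) with cp_world_size > 1; the loop and tensor names are illustrative, not copied from RingAttentionFunc:

from picotron.context_parallel.cp_communications import ContextCommunicate

def ring_rotate_kv(k, v):
    # After cp_world_size - 1 exchanges, every rank has seen every k/v shard.
    comm = ContextCommunicate("kv_exchange")
    for step in range(comm.world_size):
        last_step = (step + 1 == comm.world_size)
        if not last_step:
            # Enqueue and launch the next exchange before computing on the
            # current shard, so communication overlaps with compute.
            next_k = comm.send_recv(k)
            next_v = comm.send_recv(v)
            comm.commit()
        # ... compute the partial attention output with the current k, v here ...
        if not last_step:
            comm.wait()            # blocks until the batched isend/irecv finish
            k, v = next_k, next_v
    return k, v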
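
pipeline_communicate and bidirectional_pipeline_communicate move over unchanged: both are blocking helpers that return the received tensor, or None at a pipeline boundary. A rough sketch of one forward microbatch through pipeline_communicate, again assuming initialized process groups; model_stage, batch, and act_shape are placeholders for this sketch, not picotron APIs:

from picotron.pipeline_parallel.pp_communications import pipeline_communicate

def forward_one_microbatch(model_stage, batch, device, dtype, act_shape):
    # First stage: recv_forward returns None and the stage consumes the real batch.
    input_tensor = pipeline_communicate('recv_forward', device=device, dtype=dtype, shapes=act_shape)
    hidden = model_stage(batch if input_tensor is None else input_tensor)  # placeholder compute step
    # Last stage: send_forward is a no-op; the loss is computed locally instead.
    pipeline_communicate('send_forward', device=device, dtype=dtype, tensor=hidden)
    # For the 1F1B schedule, the fused variant exchanges both directions at once:
    #   output_grad = bidirectional_pipeline_communicate('send_fwd_recv_bwd', hidden, act_shape, device, dtype)
    return hidden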
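
Finally, the loss reduction that used to live in distributed_primtives is now inlined in train.py with the same semantics: only the last pipeline stage holds a loss, and it is averaged over the combined context/data-parallel group. Since a cp_dp group normally only spans ranks at the same pipeline stage, gating the collective on pp_is_last_stage does not desynchronize the group. Call-site sketch (variable names illustrative):

# After the pipeline step, `local_loss` is a float on the last stage and None elsewhere;
# every rank gets a float back, but only the last stage's value is the averaged loss.
loss = all_reduce_loss_across_dp_cp_ranks(local_loss, device)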