diff --git a/.gitignore b/.gitignore
index 35c90d0..19c7345 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,7 @@ __pycache__
 .vscode/
 picotron.egg-info
 *.ipynb
-wandb
\ No newline at end of file
+wandb
+tmp
+debug
+bench
\ No newline at end of file
diff --git a/model.py b/model.py
index e7bf4ee..a65cb08 100644
--- a/model.py
+++ b/model.py
@@ -2,11 +2,11 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from src.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
+from picotron.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 from flash_attn.ops.triton.layer_norm import layer_norm_fn
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 
 def apply_rotary_pos_emb(x, cos, sin):
     #TODO: Maybe do class RotaryEmbedding(nn.Module) later
diff --git a/src/__init__.py b/picotron/__init__.py
similarity index 100%
rename from src/__init__.py
rename to picotron/__init__.py
diff --git a/src/distributed/distributed_primtives.py b/picotron/distributed/distributed_primtives.py
similarity index 99%
rename from src/distributed/distributed_primtives.py
rename to picotron/distributed/distributed_primtives.py
index b054189..ee64d29 100644
--- a/src/distributed/distributed_primtives.py
+++ b/picotron/distributed/distributed_primtives.py
@@ -1,5 +1,5 @@
 import os
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 from typing import List, Optional
 import torch, torch.distributed as dist
 
diff --git a/src/parallel/context_parallel.py b/picotron/parallel/context_parallel.py
similarity index 98%
rename from src/parallel/context_parallel.py
rename to picotron/parallel/context_parallel.py
index 10a5b89..712bd18 100644
--- a/src/parallel/context_parallel.py
+++ b/picotron/parallel/context_parallel.py
@@ -4,8 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import distributed as dist
 from typing import Any, Optional, Tuple
-from src.distributed.distributed_primtives import ContextComms
-import src.distributed.process_group_manager as pgm
+from picotron.distributed.distributed_primtives import ContextComms
+import picotron.process_group_manager as pgm
 
 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
diff --git a/src/parallel/data_parallel/__init__.py b/picotron/parallel/data_parallel/__init__.py
similarity index 100%
rename from src/parallel/data_parallel/__init__.py
rename to picotron/parallel/data_parallel/__init__.py
diff --git a/src/parallel/data_parallel/bucket.py b/picotron/parallel/data_parallel/bucket.py
similarity index 100%
rename from src/parallel/data_parallel/bucket.py
rename to picotron/parallel/data_parallel/bucket.py
diff --git a/src/parallel/data_parallel/data_parallel.py b/picotron/parallel/data_parallel/data_parallel.py
similarity index 97%
rename from src/parallel/data_parallel/data_parallel.py
rename to picotron/parallel/data_parallel/data_parallel.py
index a590b38..9bbbef6 100644
--- a/src/parallel/data_parallel/data_parallel.py
+++ b/picotron/parallel/data_parallel/data_parallel.py
@@ -2,7 +2,7 @@ import contextlib
 import torch
 import torch.distributed as dist
 from torch import nn
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 
 class DataParallel(nn.Module):
     def __init__(self, module):
diff --git a/src/parallel/data_parallel/data_parallel_bucket.py b/picotron/parallel/data_parallel/data_parallel_bucket.py
similarity index 98%
rename from src/parallel/data_parallel/data_parallel_bucket.py
rename to picotron/parallel/data_parallel/data_parallel_bucket.py
index 4423d6f..90956f3 100644
--- a/src/parallel/data_parallel/data_parallel_bucket.py
+++ b/picotron/parallel/data_parallel/data_parallel_bucket.py
@@ -3,8 +3,8 @@ import torch
 import torch.distributed as dist
 from torch import nn
 from torch.autograd import Variable
-from src.parallel.data_parallel.bucket import BucketManager
-import src.distributed.process_group_manager as pgm
+from picotron.parallel.data_parallel.bucket import BucketManager
+import picotron.process_group_manager as pgm
 
 class DataParallel(nn.Module):
     def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
diff --git a/src/parallel/pipeline_parallel.py b/picotron/parallel/pipeline_parallel.py
similarity index 97%
rename from src/parallel/pipeline_parallel.py
rename to picotron/parallel/pipeline_parallel.py
index 84663b2..f22bfd7 100644
--- a/src/parallel/pipeline_parallel.py
+++ b/picotron/parallel/pipeline_parallel.py
@@ -1,5 +1,5 @@
-import src.distributed.process_group_manager as pgm
-from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
+import picotron.process_group_manager as pgm
+from picotron.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
 import os
 
diff --git a/src/parallel/tensor_parallel/__init__.py b/picotron/parallel/tensor_parallel/__init__.py
similarity index 100%
rename from src/parallel/tensor_parallel/__init__.py
rename to picotron/parallel/tensor_parallel/__init__.py
diff --git a/src/parallel/tensor_parallel/layers.py b/picotron/parallel/tensor_parallel/layers.py
similarity index 97%
rename from src/parallel/tensor_parallel/layers.py
rename to picotron/parallel/tensor_parallel/layers.py
index 148f60a..14a42ef 100644
--- a/src/parallel/tensor_parallel/layers.py
+++ b/picotron/parallel/tensor_parallel/layers.py
@@ -2,15 +2,15 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import VocabUtility
+from picotron.parallel.tensor_parallel.utils import VocabUtility
 import torch
 import math
 import torch.nn.init as init
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 from typing import Callable, Optional
-import src.distributed.process_group_manager as pgm
-from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
+import picotron.process_group_manager as pgm
+from picotron.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
 
 def initialize_weight_tensor(weight, vocab_embedding=False):
     """
diff --git a/src/parallel/tensor_parallel/mappings.py b/picotron/parallel/tensor_parallel/mappings.py
similarity index 96%
rename from src/parallel/tensor_parallel/mappings.py
rename to picotron/parallel/tensor_parallel/mappings.py
index f0e4aed..fa18b7b 100644
--- a/src/parallel/tensor_parallel/mappings.py
+++ b/picotron/parallel/tensor_parallel/mappings.py
@@ -2,10 +2,10 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import split_tensor_along_last_dim
+from picotron.parallel.tensor_parallel.utils import split_tensor_along_last_dim
 import torch.distributed as dist
 import torch
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 
 def _reduce(input_):
     """All-reduce the input tensor across model parallel(Tensor Parallel) group."""
diff --git a/src/parallel/tensor_parallel/tensor_parallel.py b/picotron/parallel/tensor_parallel/tensor_parallel.py
similarity index 93%
rename from src/parallel/tensor_parallel/tensor_parallel.py
rename to picotron/parallel/tensor_parallel/tensor_parallel.py
index d6712d2..113e664 100644
--- a/src/parallel/tensor_parallel/tensor_parallel.py
+++ b/picotron/parallel/tensor_parallel/tensor_parallel.py
@@ -1,5 +1,5 @@
 from functools import partial
-from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
+from picotron.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
 import torch.nn.init as init
 import torch.nn as nn
 
diff --git a/src/parallel/tensor_parallel/utils.py b/picotron/parallel/tensor_parallel/utils.py
similarity index 100%
rename from src/parallel/tensor_parallel/utils.py
rename to picotron/parallel/tensor_parallel/utils.py
diff --git a/src/distributed/process_group_manager.py b/picotron/process_group_manager.py
similarity index 100%
rename from src/distributed/process_group_manager.py
rename to picotron/process_group_manager.py
diff --git a/train.py b/train.py
index 309653a..593ffda 100644
--- a/train.py
+++ b/train.py
@@ -20,15 +20,15 @@ import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
 import numpy as np
-from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
-import src.distributed.process_group_manager as pgm
+from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel
+import picotron.process_group_manager as pgm
 from utils import MicroBatchDataLoader, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
-from src.distributed.process_group_manager import setup_process_group_manager
-from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from src.parallel.data_parallel.data_parallel_bucket import DataParallel
+from picotron.process_group_manager import setup_process_group_manager
+from picotron.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel
 from model import Llama
 import wandb
-from src.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks
+from picotron.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks
 
 def train_step(model, data_loader, device):
     acc_loss = 0.0
diff --git a/utils.py b/utils.py
index 6a1cf4a..dc63ea0 100644
--- a/utils.py
+++ b/utils.py
@@ -4,7 +4,7 @@ import os
 import numpy as np
 import builtins
 import fcntl
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 import torch
 from torch.utils.data import DataLoader, DistributedSampler
 from functools import partial
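Usage note (a sketch, not part of the diff): after this rename the former src package is imported as picotron, and process_group_manager moves from the distributed subpackage to the package root, while the parallelism modules keep their subpackage layout. Assuming the project is installed as the picotron package (e.g. an editable install, consistent with the existing picotron.egg-info entry in .gitignore), downstream imports look like this; all paths below are taken verbatim from the new-side lines of the diff:

# post-rename import layout (paths copied from the "+" lines above)
import picotron.process_group_manager as pgm                                    # was: src.distributed.process_group_manager
from picotron.distributed.distributed_primtives import pipeline_communicate     # distributed primitives keep their subpackage
from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel    # was: src.parallel.tensor_parallel.tensor_parallel
from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel   # was: src.parallel.data_parallel.data_parallel_bucket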