picotron top level folder

ferdinand.mom 2024-11-04 15:29:26 +00:00
parent e7b4722160
commit 8af19d0caa
18 changed files with 27 additions and 24 deletions

.gitignore (vendored): 5 changes
View File

@@ -3,4 +3,7 @@ __pycache__
 .vscode/
 picotron.egg-info
 *.ipynb
-wandb
+wandb
+tmp
+debug
+bench

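Note: the pre-existing picotron.egg-info entry in this .gitignore suggests the picotron package is installed into the environment (most likely with pip install -e .), which is what lets the absolute picotron.* imports in the files below resolve. A hypothetical minimal setup.py sketch, not part of this commit and shown only for illustration:

# Hypothetical minimal setup.py (illustrative only, not part of this commit).
# An editable install of a layout like this would expose the top-level
# `picotron` package and produce the ignored picotron.egg-info directory.
from setuptools import setup, find_packages

setup(
    name="picotron",
    version="0.1.0",
    packages=find_packages(include=["picotron", "picotron.*"]),
)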
View File

@@ -2,11 +2,11 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from src.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
+from picotron.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 from flash_attn.ops.triton.layer_norm import layer_norm_fn
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 def apply_rotary_pos_emb(x, cos, sin):
     #TODO: Maybe do class RotaryEmbedding(nn.Module) later

View File

@@ -1,5 +1,5 @@
 import os
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 from typing import List, Optional
 import torch, torch.distributed as dist

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import distributed as dist
 from typing import Any, Optional, Tuple
-from src.distributed.distributed_primtives import ContextComms
-import src.distributed.process_group_manager as pgm
+from picotron.distributed.distributed_primtives import ContextComms
+import picotron.process_group_manager as pgm
 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)

View File

@@ -2,7 +2,7 @@ import contextlib
 import torch
 import torch.distributed as dist
 from torch import nn
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 class DataParallel(nn.Module):
     def __init__(self, module):

View File

@@ -3,8 +3,8 @@ import torch
 import torch.distributed as dist
 from torch import nn
 from torch.autograd import Variable
-from src.parallel.data_parallel.bucket import BucketManager
-import src.distributed.process_group_manager as pgm
+from picotron.parallel.data_parallel.bucket import BucketManager
+import picotron.process_group_manager as pgm
 class DataParallel(nn.Module):
     def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):

View File

@@ -1,5 +1,5 @@
-import src.distributed.process_group_manager as pgm
-from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
+import picotron.process_group_manager as pgm
+from picotron.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
 import os

View File

@@ -2,15 +2,15 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import VocabUtility
+from picotron.parallel.tensor_parallel.utils import VocabUtility
 import torch
 import math
 import torch.nn.init as init
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 from typing import Callable, Optional
-import src.distributed.process_group_manager as pgm
-from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
+import picotron.process_group_manager as pgm
+from picotron.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
 def initialize_weight_tensor(weight, vocab_embedding=False):
     """

View File

@@ -2,10 +2,10 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import split_tensor_along_last_dim
+from picotron.parallel.tensor_parallel.utils import split_tensor_along_last_dim
 import torch.distributed as dist
 import torch
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 def _reduce(input_):
     """All-reduce the input tensor across model parallel(Tensor Parallel) group."""

View File

@@ -1,5 +1,5 @@
 from functools import partial
-from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
+from picotron.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
 import torch.nn.init as init
 import torch.nn as nn

View File

@@ -20,15 +20,15 @@ import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
 import numpy as np
-from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
-import src.distributed.process_group_manager as pgm
+from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel
+import picotron.process_group_manager as pgm
 from utils import MicroBatchDataLoader, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
-from src.distributed.process_group_manager import setup_process_group_manager
-from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from src.parallel.data_parallel.data_parallel_bucket import DataParallel
+from picotron.process_group_manager import setup_process_group_manager
+from picotron.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel
 from model import Llama
 import wandb
-from src.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks
+from picotron.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks
 def train_step(model, data_loader, device):
     acc_loss = 0.0

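All of the former src.* module paths imported by the training script above now resolve under the picotron package. A hedged smoke test, not part of this commit, that only imports module paths actually named in this diff (it assumes the package and its dependencies are installed):

# Smoke test sketch (illustrative only): confirm the renamed module paths resolve.
import importlib

for module_name in (
    "picotron.process_group_manager",
    "picotron.parallel.pipeline_parallel",
    "picotron.parallel.tensor_parallel.tensor_parallel",
    "picotron.parallel.data_parallel.data_parallel_bucket",
    "picotron.distributed.distributed_primtives",
):
    importlib.import_module(module_name)  # raises ModuleNotFoundError if the move is incomplete
print("all picotron module paths resolve")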
View File

@@ -4,7 +4,7 @@ import os
 import numpy as np
 import builtins
 import fcntl
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 import torch
 from torch.utils.data import DataLoader, DistributedSampler
 from functools import partial
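With 18 files touched by this rename, a quick check that no stale src.* imports remain is to scan the tree. A hypothetical helper sketch, not part of this commit:

# Scan for leftover "src." imports that should have become "picotron." (illustrative only).
import pathlib
import re

stale = re.compile(r"^\s*(?:from|import)\s+src\.", re.MULTILINE)
for path in pathlib.Path(".").rglob("*.py"):
    text = path.read_text(encoding="utf-8", errors="ignore")
    if stale.search(text):
        print(f"stale src.* import in {path}")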