picotron top level folder
This commit is contained in:
parent e7b4722160
commit 8af19d0caa

.gitignore (vendored): 5 lines changed
@@ -3,4 +3,7 @@ __pycache__
 .vscode/
+picotron.egg-info
 *.ipynb
 wandb
-wandb
+tmp
+debug
+bench

model.py: 4 lines changed
@@ -2,11 +2,11 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from src.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
+from picotron.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 from flash_attn.ops.triton.layer_norm import layer_norm_fn
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm

 def apply_rotary_pos_emb(x, cos, sin):
     #TODO: Maybe do class RotaryEmbedding(nn.Module) later

@@ -1,5 +1,5 @@
 import os
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm
 from typing import List, Optional
 import torch, torch.distributed as dist

@@ -4,8 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import distributed as dist
 from typing import Any, Optional, Tuple
-from src.distributed.distributed_primtives import ContextComms
-import src.distributed.process_group_manager as pgm
+from picotron.distributed.distributed_primtives import ContextComms
+import picotron.process_group_manager as pgm

 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)

@@ -2,7 +2,7 @@ import contextlib
 import torch
 import torch.distributed as dist
 from torch import nn
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm

 class DataParallel(nn.Module):
     def __init__(self, module):

@@ -3,8 +3,8 @@ import torch
 import torch.distributed as dist
 from torch import nn
 from torch.autograd import Variable
-from src.parallel.data_parallel.bucket import BucketManager
-import src.distributed.process_group_manager as pgm
+from picotron.parallel.data_parallel.bucket import BucketManager
+import picotron.process_group_manager as pgm

 class DataParallel(nn.Module):
     def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):

@@ -1,5 +1,5 @@
-import src.distributed.process_group_manager as pgm
-from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
+import picotron.process_group_manager as pgm
+from picotron.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
 import os

@@ -2,15 +2,15 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import VocabUtility
+from picotron.parallel.tensor_parallel.utils import VocabUtility
 import torch
 import math
 import torch.nn.init as init
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 from typing import Callable, Optional
-import src.distributed.process_group_manager as pgm
-from src.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
+import picotron.process_group_manager as pgm
+from picotron.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region

 def initialize_weight_tensor(weight, vocab_embedding=False):
     """

@@ -2,10 +2,10 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from src.parallel.tensor_parallel.utils import split_tensor_along_last_dim
+from picotron.parallel.tensor_parallel.utils import split_tensor_along_last_dim
 import torch.distributed as dist
 import torch
-import src.distributed.process_group_manager as pgm
+import picotron.process_group_manager as pgm

 def _reduce(input_):
     """All-reduce the input tensor across model parallel(Tensor Parallel) group."""

@@ -1,5 +1,5 @@
 from functools import partial
-from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
+from picotron.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
 import torch.nn.init as init
 import torch.nn as nn

train.py: 12 lines changed
@@ -20,15 +20,15 @@ import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
 import numpy as np
-from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
-import src.distributed.process_group_manager as pgm
+from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel
+import picotron.process_group_manager as pgm
 from utils import MicroBatchDataLoader, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
-from src.distributed.process_group_manager import setup_process_group_manager
-from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from src.parallel.data_parallel.data_parallel_bucket import DataParallel
+from picotron.process_group_manager import setup_process_group_manager
+from picotron.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel
 from model import Llama
 import wandb
-from src.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks
+from picotron.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks

 def train_step(model, data_loader, device):
     acc_loss = 0.0
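
Note: the import changes above rename the package root from src.* to a top-level picotron package, and the new picotron.egg-info entry in .gitignore suggests the project is now installed as a Python package (for example with pip install -e .). The packaging file itself is not part of this diff; the following is only a hypothetical minimal sketch, assuming setuptools, of the kind of setup.py that would make the picotron.* imports above resolvable and generate the ignored picotron.egg-info metadata:

# setup.py (hypothetical sketch, not from this commit)
# Installing with `pip install -e .` lets `import picotron.process_group_manager`
# and the other picotron.* imports in this diff resolve from anywhere,
# and writes the picotron.egg-info/ directory ignored above.
from setuptools import setup, find_packages

setup(
    name="picotron",
    version="0.1.0",            # assumed placeholder version
    packages=find_packages(),   # picks up the picotron/ top-level folder
)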