refactor organisation

ferdinand.mom 2024-10-10 15:12:14 +00:00
parent 47581d29e9
commit 3095ff4d4f
12 changed files with 49 additions and 54 deletions

View File

@@ -10,7 +10,7 @@ from utils import set_all_seed
import lovely_tensors as lt; lt.monkey_patch()
from model import Llama
-from process_group_manager import setup_process_group_manager
+from distributed.process_group_manager import setup_process_group_manager
def sanity_check_weights(model, model_hf, picotron_to_hf):
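Note: the import renames in this commit suggest the flat modules were regrouped into packages. A rough sketch of the layout implied by the new import paths (the directory structure is inferred from the imports, not listed in this diff):

distributed/
    process_group_manager.py
    distributed_primtives.py
parallel/
    pipeline_parallel.py
    data_parallel.py
    context_parallel.py
model.py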

View File

View File

@@ -1,32 +0,0 @@
-import torch
-import torch.distributed as dist
-from transformers import AutoTokenizer
-from torch.utils.data import DataLoader, DistributedSampler
-from datasets import load_dataset
-import process_group_manager as pgm
-class MicroBatchDataLoader(DataLoader):
-    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
-        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
-        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
-        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
-        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-        self.dataset = load_dataset(dataset_name, split=split)
-        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
-        dist.barrier()
-        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
-        self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
-        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
-    def set_epoch(self, epoch):
-        self.sampler.set_epoch(epoch)
-    def collate_batch(self, batch_data):
-        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
-        batch_size, seq_len = batch_input_ids.shape
-        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
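Note: a hedged usage sketch of the loader deleted above (it reappears inline in the training script later in this diff). The dataset and tokenizer names and the batch sizes are illustrative placeholders, not values from this commit, and the process group manager must already be initialised because __init__ reads pgm.process_group_manager.dp_world_size and dp_rank:

# Illustrative only: names and sizes below are assumptions, not taken from this commit.
# Assumes torch.distributed is initialised and setup_process_group_manager() has run,
# so pgm.process_group_manager exposes dp_world_size / dp_rank.
loader = MicroBatchDataLoader(
    global_batch_size=32,                         # split across data-parallel ranks
    micro_batch_size=4,                           # per-rank batch handed to the model
    seq_length=128,
    dataset_name="roneneldan/TinyStories",        # placeholder dataset
    tokenizer_name="HuggingFaceTB/SmolLM-135M",   # placeholder tokenizer
    num_samples=1000,
)
loader.set_epoch(0)
batch = next(iter(loader))
# batch keys: input_ids, target_ids, position_index, attn_mask, hidden_states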

View File

@@ -1,7 +1,7 @@
import os
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
import torch, torch.distributed as dist
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
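Note: throughout this refactor the manager is imported as a module alias (import distributed.process_group_manager as pgm) rather than importing the singleton object directly, so reads such as pgm.process_group_manager.dp_world_size pick up whatever setup_process_group_manager() installed at runtime. A minimal sketch of the access pattern; the setup call's arguments are not shown in this diff, so they are left out:

import distributed.process_group_manager as pgm
# setup_process_group_manager(...) must have been called first; its signature is not part of this diff.
dp_world_size = pgm.process_group_manager.dp_world_size
dp_rank = pgm.process_group_manager.dp_rank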

View File

@@ -2,13 +2,13 @@
import os
import argparse
import torch, torch.distributed as dist
-from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM,AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig, AutoTokenizer
from utils import set_all_seed
-import process_group_manager as pgm
-from process_group_manager import setup_process_group_manager
-from pipeline_parallel import PipelineParallel
-from distributed_primtives import communicate
+import distributed.process_group_manager as pgm
+from distributed.process_group_manager import setup_process_group_manager
+from parallel.pipeline_parallel import PipelineParallel
+from distributed.distributed_primtives import communicate
from model import Llama
def run_one_inference_step(model, batch, device, config) -> torch.Tensor:

View File

@@ -1,14 +1,12 @@
import torch.distributed as dist
import torch.nn as nn
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
class ContextParallel(nn.Module):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.cp_world_size = pgm.process_group_manager.cp_world_size
        self.cp_rank = pgm.process_group_manager.cp_rank
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
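Note: in the lines shown, ContextParallel is a transparent wrapper: forward() delegates to the wrapped module and __init__ only records the context-parallel world size and rank; config is not used in the visible lines. A hedged usage sketch (any nn.Module can stand in for the real Llama model, and pgm.process_group_manager must already be initialised):

import torch
import torch.nn as nn
model = nn.Linear(8, 8)                            # stand-in for the real model
cp_model = ContextParallel(model, config=None)     # config unused in the lines shown above
out = cp_model(torch.randn(2, 8))                  # forwarded unchanged to the wrapped module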

View File

@@ -1,6 +1,6 @@
import torch.distributed as dist
import torch.nn as nn
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
class DataParallel(nn.Module):
    def __init__(self, model, config):

View File

@@ -1,5 +1,5 @@
-import process_group_manager as pgm
-from distributed_primtives import communicate, bidirectional_communicate
+import distributed.process_group_manager as pgm
+from distributed.distributed_primtives import communicate, bidirectional_communicate
import torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist

View File

@@ -0,0 +1 @@
+#TODO

View File

@@ -3,20 +3,48 @@ import os
import torch.nn.functional as F
import torch, torch.distributed as dist
from torch.optim import AdamW
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
import argparse
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
from utils import set_all_seed, display_parallelism_grid, print
-from process_group_manager import setup_process_group_manager
-from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from data_parallel import DataParallel
-from context_parallel import ContextParallel
+from distributed.process_group_manager import setup_process_group_manager
+from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from parallel.data_parallel import DataParallel
+from parallel.context_parallel import ContextParallel
from model import Llama
-from dataset import MicroBatchDataLoader
import wandb
+class MicroBatchDataLoader(DataLoader):
+    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
+        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
+        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
+        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
+        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        self.dataset = load_dataset(dataset_name, split=split)
+        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+        dist.barrier()
+        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
+        self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
+        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
+    def set_epoch(self, epoch):
+        self.sampler.set_epoch(epoch)
+    def collate_batch(self, batch_data):
+        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
+        batch_size, seq_len = batch_input_ids.shape
+        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
def train_step(model, data_loader, device):
    total_loss = 0.0
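Note: the collate_batch moved into the training script above produces sequence-first tensors: input_ids and target_ids are transposed to (seq_length, batch_size) with the targets shifted by one token, position_index enumerates token positions broadcast across the batch, and attn_mask is a per-sample causal (lower-triangular) mask. A self-contained illustration of that shape contract with toy numbers, independent of the real tokenizer and dataset:

import torch
batch_size, seq_len = 4, 9                                # seq_len includes the extra token used for the shifted targets
batch_input_ids = torch.randint(0, 100, (batch_size, seq_len))
input_ids  = batch_input_ids[:, :-1].T.contiguous()       # (seq_len-1, batch_size)
target_ids = batch_input_ids[:, 1:].T.contiguous()        # (seq_len-1, batch_size), shifted by one position
position_index = torch.arange(seq_len - 1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous()
attn_mask = torch.tril(torch.ones((seq_len - 1, seq_len - 1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous()
assert input_ids.shape == target_ids.shape == (seq_len - 1, batch_size)
assert attn_mask.shape == (batch_size, seq_len - 1, seq_len - 1)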

View File

@@ -3,7 +3,7 @@ import random
import numpy as np
import builtins
import fcntl
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
def print(*args, **kwargs):
""" solves multi-process interleaved print problem """