refactor organisation
This commit is contained in:
parent
47581d29e9
commit
3095ff4d4f
@ -10,7 +10,7 @@ from utils import set_all_seed
|
||||
import lovely_tensors as lt; lt.monkey_patch()
|
||||
|
||||
from model import Llama
|
||||
from process_group_manager import setup_process_group_manager
|
||||
from distributed.process_group_manager import setup_process_group_manager
|
||||
|
||||
def sanity_check_weights(model, model_hf, picotron_to_hf):
|
||||
|
||||
|
||||
0
convert_picotron_to_hf.py
Normal file
0
convert_picotron_to_hf.py
Normal file
32
dataset.py
32
dataset.py
@ -1,32 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from datasets import load_dataset
|
||||
|
||||
import process_group_manager as pgm
|
||||
|
||||
class MicroBatchDataLoader(DataLoader):
|
||||
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
|
||||
self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
|
||||
self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
|
||||
self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
|
||||
self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
self.dataset = load_dataset(dataset_name, split=split)
|
||||
if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
|
||||
dist.barrier()
|
||||
self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
|
||||
|
||||
self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
|
||||
|
||||
super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.sampler.set_epoch(epoch)
|
||||
|
||||
def collate_batch(self, batch_data):
|
||||
batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
|
||||
batch_size, seq_len = batch_input_ids.shape
|
||||
return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
|
||||
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import process_group_manager as pgm
|
||||
import distributed.process_group_manager as pgm
|
||||
import torch, torch.distributed as dist
|
||||
import process_group_manager as pgm
|
||||
import distributed.process_group_manager as pgm
|
||||
|
||||
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
|
||||
|
||||
10
generate.py
10
generate.py
@ -2,13 +2,13 @@
|
||||
import os
|
||||
import argparse
|
||||
import torch, torch.distributed as dist
|
||||
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM,AutoTokenizer
|
||||
from transformers import AutoTokenizer, AutoConfig, AutoTokenizer
|
||||
|
||||
from utils import set_all_seed
|
||||
import process_group_manager as pgm
|
||||
from process_group_manager import setup_process_group_manager
|
||||
from pipeline_parallel import PipelineParallel
|
||||
from distributed_primtives import communicate
|
||||
import distributed.process_group_manager as pgm
|
||||
from distributed.process_group_manager import setup_process_group_manager
|
||||
from parallel.pipeline_parallel import PipelineParallel
|
||||
from distributed.distributed_primtives import communicate
|
||||
from model import Llama
|
||||
|
||||
def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
|
||||
|
||||
@ -1,14 +1,12 @@
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import process_group_manager as pgm
|
||||
import distributed.process_group_manager as pgm
|
||||
|
||||
|
||||
class ContextParallel(nn.Module):
|
||||
def __init__(self, model, config):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.cp_world_size = pgm.process_group_manager.cp_world_size
|
||||
self.cp_rank = pgm.process_group_manager.cp_rank
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
@ -1,6 +1,6 @@
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import process_group_manager as pgm
|
||||
import distributed.process_group_manager as pgm
|
||||
|
||||
class DataParallel(nn.Module):
|
||||
def __init__(self, model, config):
|
||||
@ -1,5 +1,5 @@
|
||||
import process_group_manager as pgm
|
||||
from distributed_primtives import communicate, bidirectional_communicate
|
||||
import distributed.process_group_manager as pgm
|
||||
from distributed.distributed_primtives import communicate, bidirectional_communicate
|
||||
import torch, torch.nn as nn, torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
1
parallel/tensor_parallel.py
Normal file
1
parallel/tensor_parallel.py
Normal file
@ -0,0 +1 @@
|
||||
#TODO
|
||||
42
train.py
42
train.py
@ -3,20 +3,48 @@ import os
|
||||
import torch.nn.functional as F
|
||||
import torch, torch.distributed as dist
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from datasets import load_dataset
|
||||
import argparse
|
||||
|
||||
import process_group_manager as pgm
|
||||
import distributed.process_group_manager as pgm
|
||||
from utils import set_all_seed, display_parallelism_grid, print
|
||||
from process_group_manager import setup_process_group_manager
|
||||
from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
|
||||
from data_parallel import DataParallel
|
||||
from context_parallel import ContextParallel
|
||||
from distributed.process_group_manager import setup_process_group_manager
|
||||
from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
|
||||
from parallel.data_parallel import DataParallel
|
||||
from parallel.context_parallel import ContextParallel
|
||||
from model import Llama
|
||||
from dataset import MicroBatchDataLoader
|
||||
import wandb
|
||||
|
||||
class MicroBatchDataLoader(DataLoader):
|
||||
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
|
||||
self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
|
||||
self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
|
||||
self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
|
||||
self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
self.dataset = load_dataset(dataset_name, split=split)
|
||||
if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
|
||||
dist.barrier()
|
||||
self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
|
||||
|
||||
self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
|
||||
|
||||
super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.sampler.set_epoch(epoch)
|
||||
|
||||
def collate_batch(self, batch_data):
|
||||
batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
|
||||
batch_size, seq_len = batch_input_ids.shape
|
||||
return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
|
||||
|
||||
|
||||
def train_step(model, data_loader, device):
|
||||
total_loss = 0.0
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user