From 3095ff4d4f0bbb779f0a661308e47df998e5a79d Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Thu, 10 Oct 2024 15:12:14 +0000
Subject: [PATCH] refactor organisation

---
 convert_hf_to_picotron.py                     |  2 +-
 convert_picotron_to_hf.py                     |  0
 dataset.py                                    | 32 --------------
 .../distributed_primtives.py                  |  4 +-
 .../process_group_manager.py                  |  0
 generate.py                                   | 10 ++---
 .../context_parallel.py                       |  4 +-
 data_parallel.py => parallel/data_parallel.py |  2 +-
 .../pipeline_parallel.py                      |  4 +-
 parallel/tensor_parallel.py                   |  1 +
 train.py                                      | 42 +++++++++++++++----
 utils.py                                      |  2 +-
 12 files changed, 49 insertions(+), 54 deletions(-)
 create mode 100644 convert_picotron_to_hf.py
 delete mode 100644 dataset.py
 rename distributed_primtives.py => distributed/distributed_primtives.py (96%)
 rename process_group_manager.py => distributed/process_group_manager.py (100%)
 rename context_parallel.py => parallel/context_parallel.py (72%)
 rename data_parallel.py => parallel/data_parallel.py (90%)
 rename pipeline_parallel.py => parallel/pipeline_parallel.py (98%)
 create mode 100644 parallel/tensor_parallel.py

diff --git a/convert_hf_to_picotron.py b/convert_hf_to_picotron.py
index 0fb9ee9..d308f5e 100644
--- a/convert_hf_to_picotron.py
+++ b/convert_hf_to_picotron.py
@@ -10,7 +10,7 @@ from utils import set_all_seed
 import lovely_tensors as lt; lt.monkey_patch()
 
 from model import Llama
-from process_group_manager import setup_process_group_manager
+from distributed.process_group_manager import setup_process_group_manager
 
 
 def sanity_check_weights(model, model_hf, picotron_to_hf):
diff --git a/convert_picotron_to_hf.py b/convert_picotron_to_hf.py
new file mode 100644
index 0000000..e69de29
diff --git a/dataset.py b/dataset.py
deleted file mode 100644
index cac809c..0000000
--- a/dataset.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import torch
-import torch.distributed as dist
-from transformers import AutoTokenizer
-from torch.utils.data import DataLoader, DistributedSampler
-from datasets import load_dataset
-
-import process_group_manager as pgm
-
-class MicroBatchDataLoader(DataLoader):
-    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
-        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
-        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
-        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
-        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
-
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-        self.dataset = load_dataset(dataset_name, split=split)
-        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
-        dist.barrier()
-        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
-
-        self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
-
-        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
-
-    def set_epoch(self, epoch):
-        self.sampler.set_epoch(epoch)
-
-    def collate_batch(self, batch_data):
-        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
-        batch_size, seq_len = batch_input_ids.shape
-        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
diff --git a/distributed_primtives.py b/distributed/distributed_primtives.py
similarity index 96%
rename from distributed_primtives.py
rename to distributed/distributed_primtives.py
index 824f735..dd65185 100644
--- a/distributed_primtives.py
+++ b/distributed/distributed_primtives.py
@@ -1,7 +1,7 @@
 import os
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
 import torch, torch.distributed as dist
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
 
 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
 
diff --git a/process_group_manager.py b/distributed/process_group_manager.py
similarity index 100%
rename from process_group_manager.py
rename to distributed/process_group_manager.py
diff --git a/generate.py b/generate.py
index bb6ac4f..33c5348 100644
--- a/generate.py
+++ b/generate.py
@@ -2,13 +2,13 @@ import os
 import argparse
 
 import torch, torch.distributed as dist
-from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM,AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig, AutoTokenizer
 
 from utils import set_all_seed
-import process_group_manager as pgm
-from process_group_manager import setup_process_group_manager
-from pipeline_parallel import PipelineParallel
-from distributed_primtives import communicate
+import distributed.process_group_manager as pgm
+from distributed.process_group_manager import setup_process_group_manager
+from parallel.pipeline_parallel import PipelineParallel
+from distributed.distributed_primtives import communicate
 from model import Llama
 
 def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
diff --git a/context_parallel.py b/parallel/context_parallel.py
similarity index 72%
rename from context_parallel.py
rename to parallel/context_parallel.py
index 85e45b8..b4b5da8 100644
--- a/context_parallel.py
+++ b/parallel/context_parallel.py
@@ -1,14 +1,12 @@
 import torch.distributed as dist
 import torch.nn as nn
 
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
 
 class ContextParallel(nn.Module):
     def __init__(self, model, config):
         super().__init__()
         self.model = model
-        self.cp_world_size = pgm.process_group_manager.cp_world_size
-        self.cp_rank = pgm.process_group_manager.cp_rank
 
     def forward(self, *args, **kwargs):
         return self.model(*args, **kwargs)
diff --git a/data_parallel.py b/parallel/data_parallel.py
similarity index 90%
rename from data_parallel.py
rename to parallel/data_parallel.py
index 104b01a..9148876 100644
--- a/data_parallel.py
+++ b/parallel/data_parallel.py
@@ -1,6 +1,6 @@
 import torch.distributed as dist
 import torch.nn as nn
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
 
 class DataParallel(nn.Module):
     def __init__(self, model, config):
diff --git a/pipeline_parallel.py b/parallel/pipeline_parallel.py
similarity index 98%
rename from pipeline_parallel.py
rename to parallel/pipeline_parallel.py
index 1ac3005..3f92c54 100644
--- a/pipeline_parallel.py
+++ b/parallel/pipeline_parallel.py
@@ -1,5 +1,5 @@
-import process_group_manager as pgm
-from distributed_primtives import communicate, bidirectional_communicate
+import distributed.process_group_manager as pgm
+from distributed.distributed_primtives import communicate, bidirectional_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
diff --git a/parallel/tensor_parallel.py b/parallel/tensor_parallel.py
new file mode 100644
index 0000000..503fa1d
--- /dev/null
+++ b/parallel/tensor_parallel.py
@@ -0,0 +1 @@
+#TODO
\ No newline at end of file
diff --git a/train.py b/train.py
index 7c7bb59..5a52e0a 100644
--- a/train.py
+++ b/train.py
@@ -3,20 +3,48 @@ import os
 import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
-from transformers import AutoConfig, AutoModelForCausalLM
-
+from transformers import AutoConfig
+from transformers import AutoTokenizer
+from torch.utils.data import DataLoader, DistributedSampler
+from datasets import load_dataset
 import argparse
 
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
 from utils import set_all_seed, display_parallelism_grid, print
-from process_group_manager import setup_process_group_manager
-from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from data_parallel import DataParallel
-from context_parallel import ContextParallel
+from distributed.process_group_manager import setup_process_group_manager
+from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from parallel.data_parallel import DataParallel
+from parallel.context_parallel import ContextParallel
 from model import Llama
 from dataset import MicroBatchDataLoader
 import wandb
 
+class MicroBatchDataLoader(DataLoader):
+    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
+        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
+        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
+        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
+        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
+
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        self.dataset = load_dataset(dataset_name, split=split)
+        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+        dist.barrier()
+        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
+
+        self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
+
+        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
+
+    def set_epoch(self, epoch):
+        self.sampler.set_epoch(epoch)
+
+    def collate_batch(self, batch_data):
+        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
+        batch_size, seq_len = batch_input_ids.shape
+        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
+
+
 def train_step(model, data_loader, device):
     total_loss = 0.0
 
diff --git a/utils.py b/utils.py
index 9320b02..ec12408 100644
--- a/utils.py
+++ b/utils.py
@@ -3,7 +3,7 @@ import random
 import numpy as np
 import builtins
 import fcntl
-import process_group_manager as pgm
+import distributed.process_group_manager as pgm
 
 def print(*args, **kwargs):
     """ solves multi-process interleaved print problem """