From 0b1d02a40233313244afd07c9455a59540ad2d22 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Fri, 18 Oct 2024 14:33:46 +0000
Subject: [PATCH] various fix (modeling, dataloader, cpu load)

---
 convert_hf_to_picotron.py     |  11 +--
 generate.py                   |  33 +++++----
 model.py                      |   2 +-
 parallel/pipeline_parallel.py |  25 +++----
 train.py                      | 134 +++++++++++++++++++++++++---------
 5 files changed, 132 insertions(+), 73 deletions(-)

diff --git a/convert_hf_to_picotron.py b/convert_hf_to_picotron.py
index d308f5e..2d64239 100644
--- a/convert_hf_to_picotron.py
+++ b/convert_hf_to_picotron.py
@@ -116,11 +116,7 @@ if __name__ == "__main__":
 
     model_hf = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
 
-    model = Llama(
-        config=model_hf.config,
-        device=device,
-    )
-
+    model = Llama(config=model_hf.config, device=device)
     picotron_to_hf = get_weights_mapping(model_hf, to_hf=True)
 
     ref_state_dict = model_hf.state_dict()
@@ -137,10 +133,7 @@ if __name__ == "__main__":
 
     torch.save(model.state_dict(), args.save_path)
 
-    new_model = Llama(
-        config=model_hf.config,
-        device=device,
-    )
+    new_model = Llama(config=model_hf.config, device=device)
     new_model.load_state_dict(torch.load(args.save_path))
 
     print("Sanity check weight ...")
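
Note on the hunks above: convert_hf_to_picotron.py ends by saving the converted state dict, reloading it into a fresh Llama, and printing "Sanity check weight ...". Below is a minimal, self-contained sketch of that round-trip check, with a toy nn.Linear standing in for the real Llama and an illustrative helper name (state_dicts_match) that is not from the repo.

import torch
import torch.nn as nn

def state_dicts_match(ref_sd, new_sd, atol=1e-6):
    """Return True if the two state dicts have the same keys and matching tensors."""
    if ref_sd.keys() != new_sd.keys():
        return False
    return all(torch.allclose(ref_sd[k], new_sd[k], atol=atol) for k in ref_sd)

# Toy stand-in for the converted model: save, reload into a fresh instance, compare.
model = nn.Linear(4, 4)
torch.save(model.state_dict(), "toy.pth")

new_model = nn.Linear(4, 4)
new_model.load_state_dict(torch.load("toy.pth", map_location="cpu"))

print("Sanity check weight ...", state_dicts_match(model.state_dict(), new_model.state_dict()))
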
diff --git a/generate.py b/generate.py
index e34bf99..a589754 100644
--- a/generate.py
+++ b/generate.py
@@ -13,7 +13,7 @@ from model import Llama
 
 def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
     if pgm.process_group_manager.pp_world_size == 1:
-        return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_index"])
+        return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])
 
     batch_size = batch["input_ids"].shape[0]
     seq_len = batch["input_ids"].shape[1]
@@ -28,7 +28,7 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
 
     batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer
 
-    output_tensor = model.forward(batch, device)
+    output_tensor = model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])
 
     # Send output to the next stage.
     pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)
@@ -57,17 +57,18 @@ if __name__ == "__main__":
     setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1, cp_size=1)
     set_all_seed(seed=42)
 
-    #TODO: find a better way (should need to specify model_name + path to .pth)
-    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
-    config = AutoConfig.from_pretrained(model_name)
+    load2name = {
+        "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
+        "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
+    }
 
-    base_model = Llama(
-        config=config,
-        device=device,
-    )
+    config = AutoConfig.from_pretrained(load2name[args.load_path])
 
-    base_model.load_state_dict(torch.load(args.load_path))
+    base_model = Llama(config=config, device=device)
+    base_model.load_state_dict(torch.load(args.load_path, map_location="cpu"))
 
     model = PipelineParallel(base_model, config).to(device)
+    del base_model
 
     model.eval()
@@ -78,23 +79,23 @@ if __name__ == "__main__":
         "What is your favorite color?",
     ]
 
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(load2name[args.load_path])
     tokenizer.padding_side = "left"
     tokenizer.pad_token = tokenizer.eos_token
 
-    tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device=device)
+    tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
 
     for _ in range(args.max_tokens):
         # Create the batch
         seq_len = tokenized_prompts["input_ids"].shape[1]
-        position_index = torch.arange(seq_len).view(1, -1).to(device=device)
+        position_ids = torch.arange(seq_len).view(1, -1)
 
         batch_prompts = {
-            "input_ids": tokenized_prompts["input_ids"],
+            "input_ids": tokenized_prompts["input_ids"].to(device=device),
             "target_ids": None,
-            "position_index": position_index,
-            "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool),
+            "position_ids": position_ids.to(device=device),
+            "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device),
             "hidden_states": None,
         }
diff --git a/model.py b/model.py
index fc54862..a79b1a5 100644
--- a/model.py
+++ b/model.py
@@ -4,7 +4,7 @@ from torch.nn import functional as F
 from einops import rearrange
 
 class RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
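
The model.py change above makes eps a required argument of RMSNorm, so callers now have to pass the value explicitly (typically the config's rms_norm_eps) instead of silently falling back to 1e-6. A sketch of the standard RMSNorm computation for reference; only the __init__ signature is taken from the patch, and the forward shown is the usual formula, not copied from model.py.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square layer norm: weight * x / sqrt(mean(x^2) + eps)."""
    def __init__(self, hidden_size, eps):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states

# eps must now come from the caller, e.g. the HF config's rms_norm_eps field.
norm = RMSNorm(hidden_size=64, eps=1e-5)
out = norm(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])
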
diff --git a/parallel/pipeline_parallel.py b/parallel/pipeline_parallel.py
index 9282d7b..5e2d29a 100644
--- a/parallel/pipeline_parallel.py
+++ b/parallel/pipeline_parallel.py
@@ -3,12 +3,9 @@ from distributed.distributed_primtives import pipeline_communicate, bidirectiona
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
-from parallel.base_parallel import BaseParallel
-
-class PipelineParallel(BaseParallel):
+class PipelineParallel(nn.Module):
     def __init__(self, model, config):
-        super().__init__(model, config)
-        #TODO(fmom): find a better model to distributed layers without instantiating a base_model first
+        super().__init__()
         layer_distribution = self.distribute_layers(config.num_hidden_layers)
         self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
         self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in layer_distribution})
@@ -20,11 +17,11 @@ class PipelineParallel(BaseParallel):
         start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank])
         return list(range(start_layer, start_layer + layers_per_gpu[pgm.process_group_manager.pp_rank]))
 
-    def forward(self, batch, device):
-        x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device)
+    def forward(self, input_ids, position_ids, hidden_states):
+        x = hidden_states if hidden_states is not None else input_ids
         x = self.embedding(x)
         for layer in self.decoder_layers.values():
-            x = layer(x, position_ids=batch["position_index"].to(device))
+            x = layer(x, position_ids=position_ids)
         x = self.final_norm(x)
         return self.final_proj(x)
 
@@ -41,9 +38,9 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
 
     for _ in range(data_loader.num_local_micro_batches): # All forward passes
         input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
-        batch = next(iter(data_loader))
-        batch["hidden_states"] = input_tensor
-        output_tensor = model.forward(batch, device)
+        batch = next(data_loader)
+        batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
+        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
         pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
 
         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
@@ -69,9 +66,9 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
     logging_loss, input_tensors, output_tensors = 0.0, [], []
 
     def _forward_step(input_tensor):
-        batch = next(iter(data_loader))
-        batch["hidden_states"] = input_tensor
-        output_tensor = model.forward(batch, device)
+        batch = next(data_loader)
+        batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
+        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
         if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
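
PipelineParallel.distribute_layers (partially visible as context in the second hunk above) gives each pipeline stage a contiguous slice of decoder layers. A standalone sketch of that layout follows; the layers_per_gpu split rule is assumed here, since only the start_layer/return lines appear in the hunk.

def distribute_layers(num_layers, pp_rank, pp_world_size):
    """Assign a contiguous block of decoder layers to each pipeline stage,
    giving the first (num_layers % pp_world_size) stages one extra layer."""
    layers_per_gpu = [
        num_layers // pp_world_size + (1 if i < num_layers % pp_world_size else 0)
        for i in range(pp_world_size)
    ]
    start_layer = sum(layers_per_gpu[:pp_rank])
    return list(range(start_layer, start_layer + layers_per_gpu[pp_rank]))

# e.g. 32 decoder layers over 4 pipeline stages -> 8 contiguous layers per stage
for rank in range(4):
    print(rank, distribute_layers(32, rank, 4))
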
diff --git a/train.py b/train.py
index 4a7fca4..caac752 100644
--- a/train.py
+++ b/train.py
@@ -1,12 +1,13 @@
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 import os
+import numpy as np
 import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
 from transformers import AutoTokenizer
 from torch.utils.data import DataLoader, DistributedSampler
-from datasets import load_dataset
+from datasets import load_dataset,Features, Sequence, Value
 import argparse
 
 import distributed.process_group_manager as pgm
@@ -18,57 +19,121 @@ from parallel.data_parallel import DataParallel
 from parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
-import multiprocessing
 
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
-        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
-        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
+    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
+        self.global_batch_size = global_batch_size
+        self.micro_batch_size = micro_batch_size
+        self.seq_length = seq_length
+        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size # each DP rank gets a local batch
         self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
         self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
+        self.grad_acc = grad_acc
         self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
 
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         self.dataset = load_dataset(dataset_name, split=split)
-        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+        if num_samples:
+            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+        dist.barrier()
 
-        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names, num_proc=multiprocessing.cpu_count()).with_format("torch", columns=["input_ids"])
-        self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
+        # Tokenize and chunk the dataset
+        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)
 
-        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
+        self.sampler = DistributedSampler(
+            self.tokenized_dataset,
+            num_replicas=pgm.process_group_manager.dp_world_size,
+            rank=pgm.process_group_manager.dp_rank,
+            shuffle=False
+        )
+
+        super().__init__(
+            self.tokenized_dataset,
+            batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size, # in PP we split a single batch into multiple micro-batches
+            collate_fn=self.collate_batch,
+            pin_memory=True,
+            num_workers=num_workers,
+            sampler=self.sampler,
+            shuffle=False
+        )
 
-    def set_epoch(self, epoch):
-        self.sampler.set_epoch(epoch)
+    def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
+        def _tokenizer_group_text(texts):
+            tokenized_text_batch = self.tokenizer.batch_encode_plus(
+                texts,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                return_tensors='np'
+            )
+            concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
+            total_length = len(concatenated_tokens['input_ids'])
+            if total_length >= sequence_length + 1:
+                total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
+            result = {
+                'input_ids': [
+                    concatenated_tokens['input_ids'][i : i + sequence_length + 1]
+                    for i in range(0, total_length - sequence_length, sequence_length)
+                ]
+            }
+            return result
 
-    def collate_batch(self, batch_data):
-        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
-        batch_size, seq_len = batch_input_ids.shape
+        tokenized_dataset = dataset.map(
+            _tokenizer_group_text,
+            input_columns=text_column_name,
+            remove_columns=dataset.column_names,
+            features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)}),
+            batched=True,
+            num_proc=num_proc, # Adjust this based on your system capabilities
+            load_from_cache_file=True,
+            desc=f"Grouping texts in chunks of {sequence_length+1}",
+        )
+
+        return tokenized_dataset
+
+    def collate_batch(self, batch):
+        batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
+        batch_size = batch_input_ids.size(0)
         start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu
         end_idx = start_idx + self.seq_length_per_gpu
         input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
         target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
-        position_index = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
+        position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
         local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
         attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()
 
         return {
             "input_ids": input_ids,
             "target_ids": target_ids,
-            "position_index": position_index,
+            "position_ids": position_ids,
             "attn_mask": attn_mask,
             "hidden_states": None
         }
 
+    def __iter__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        return self
+
+    def __next__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        try:
+            batch = next(self._iterator)
+        except StopIteration:
+            self._iterator = None
+            raise StopIteration
+        return batch
+
 def train_step(model, data_loader, device):
     total_loss = 0.0
 
     for _ in range(data_loader.num_local_micro_batches):
-        batch = next(iter(data_loader))
-
+        batch = next(data_loader)
+
         input_ids = batch["input_ids"].to(device)
-        position_ids = batch["position_index"].to(device)
+        position_ids = batch["position_ids"].to(device)
         target_ids = batch["target_ids"].to(device)
 
         batch_size, seq_len = input_ids.shape
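
The new tokenize_dataset above concatenates the tokenized documents into one stream and cuts it into chunks of sequence_length + 1 tokens with a stride of sequence_length, so each chunk can be split into inputs and targets shifted by one token. A small numpy-only sketch of that grouping arithmetic on fake token ids:

import numpy as np

def group_tokens(token_stream, sequence_length):
    """Slice one long token stream into chunks of sequence_length + 1 tokens.
    Consecutive chunks share one boundary token, so each chunk yields an
    (input, target-shifted-by-one) pair without wasting tokens."""
    total_length = len(token_stream)
    if total_length >= sequence_length + 1:
        total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
    return [
        token_stream[i : i + sequence_length + 1]
        for i in range(0, total_length - sequence_length, sequence_length)
    ]

tokens = np.arange(23)          # pretend these are token ids from concatenated documents
chunks = group_tokens(tokens, sequence_length=8)
for c in chunks:
    print(c)                    # two chunks of 9 tokens: [0..8] and [8..16]; the tail is dropped
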
@@ -94,6 +159,7 @@ if __name__ == "__main__":
     parser.add_argument("--use_cpu", action="store_true", default=False)
     parser.add_argument("--master_addr", type=str, default="localhost")
     parser.add_argument("--master_port", type=int, default=29500)
+    parser.add_argument("--load_path", type=str, default="smollm.pth")
 
     args = parser.parse_args()
@@ -105,7 +171,7 @@ if __name__ == "__main__":
     host = os.environ["MASTER_ADDR"]
     port = int(os.environ["MASTER_PORT"])
 
-    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
+    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 4, 1, 3e-4, int(1e4), 1e6, 42
 
     assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@@ -125,9 +191,15 @@ if __name__ == "__main__":
     # display_4D_parallelism_grid()
     set_all_seed(SEED)
-    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
+
+    load2name = {
+        "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
+        "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
+    }
+
     dataset_name = "roneneldan/TinyStories"
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(load2name[args.load_path])
 
     if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
         wandb.init(
@@ -137,7 +209,7 @@ if __name__ == "__main__":
             "tensor_parallel_size": pgm.process_group_manager.tp_size,
             "pipeline_parallel_size": pgm.process_group_manager.pp_size,
             "data_parallel_size": pgm.process_group_manager.dp_size,
-            "model": model_name,
+            "model": load2name[args.load_path],
             "dataset": dataset_name,
             "max_tokens": MAX_TOKENS,
             "learning_rate": LEARNING_RATE,
@@ -147,16 +219,11 @@ if __name__ == "__main__":
             },
         )
 
-    #TODO: find a better way (should need to specify model_name + path to .pth)
-    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(load2name[args.load_path])
 
-    model = Llama(
-        config=config,
-        device=device,
-    ).to(device)
+    model = Llama(config=config, device=device)
 
-    model.load_state_dict(torch.load("smollm.pth"))
+    # model.load_state_dict(torch.load(args.load_path, map_location="cpu"))
 
     # if pgm.process_group_manager.tp_world_size > 1:
     #     model = TensorParallel(model, config).to(device)
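
The "cpu load" part of this patch (the map_location="cpu" in generate.py and the commented-out load above) follows the usual pattern of deserializing a checkpoint into host memory first and only then moving the module to its device, so torch.load does not restore the tensors onto whatever GPU they were saved from. A toy sketch of that pattern; the Sequential model and file name are placeholders, not picotron code.

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))  # toy stand-in for Llama
torch.save(model.state_dict(), "toy_ckpt.pth")

restored = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
state_dict = torch.load("toy_ckpt.pth", map_location="cpu")  # checkpoint tensors stay in host RAM
restored.load_state_dict(state_dict)
restored.to(device)                                           # single host-to-device copy afterwards
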
@@ -172,7 +239,7 @@ if __name__ == "__main__":
 
     model.train()
 
-    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
+    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=load2name[args.load_path], num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
 
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
@@ -188,7 +255,8 @@ if __name__ == "__main__":
 
     #TODO: add gradient accumulation
     while trained_tokens < MAX_TOKENS:
-        data_loader.set_epoch(step)
+        #TODO: Add epoch support
+        # data_loader.set_epoch(step)
 
         optimizer.zero_grad()
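
For reference, the batch accounting that MicroBatchDataLoader relies on, using the patch's GLOBAL_BATCH_SIZE=4 and MICRO_BATCH_SIZE=1 with a hypothetical dp_world_size=2: each DP rank gets local_batch_size samples per step, and with pp_world_size > 1 the loader yields them micro_batch_size at a time so the pipeline schedule can run num_local_micro_batches forward passes per step.

# Batch-size accounting used by MicroBatchDataLoader (dp_world_size=2 is hypothetical).
global_batch_size = 4
micro_batch_size = 1
dp_world_size = 2

local_batch_size = global_batch_size // dp_world_size           # samples per DP rank per step
num_local_micro_batches = local_batch_size // micro_batch_size  # forward passes per DP rank
num_global_micro_batches = global_batch_size // micro_batch_size

print(local_batch_size, num_local_micro_batches, num_global_micro_batches)  # 2 2 4
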