various fixes (modeling, dataloader, CPU load)

ferdinand.mom 2024-10-18 14:33:46 +00:00
parent 81726dfffe
commit 0b1d02a402
5 changed files with 132 additions and 73 deletions


@@ -116,11 +116,7 @@ if __name__ == "__main__":
model_hf = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
model = Llama(
config=model_hf.config,
device=device,
)
model = Llama(config=model_hf.config, device=device)
picotron_to_hf = get_weights_mapping(model_hf, to_hf=True)
ref_state_dict = model_hf.state_dict()
@@ -137,10 +133,7 @@ if __name__ == "__main__":
torch.save(model.state_dict(), args.save_path)
new_model = Llama(
config=model_hf.config,
device=device,
)
new_model = Llama(config=model_hf.config, device=device)
new_model.load_state_dict(torch.load(args.save_path))
print("Sanity check weight ...")


@@ -13,7 +13,7 @@ from model import Llama
def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
if pgm.process_group_manager.pp_world_size == 1:
return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_index"])
return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])
batch_size = batch["input_ids"].shape[0]
seq_len = batch["input_ids"].shape[1]
@@ -28,7 +28,7 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer
output_tensor = model.forward(batch, device)
output_tensor = model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])
# Send output to the next stage.
pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)
@@ -57,17 +57,18 @@ if __name__ == "__main__":
setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1, cp_size=1)
set_all_seed(seed=42)
#TODO: find a better way (should need to specify model_name + path to .pth)
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
config = AutoConfig.from_pretrained(model_name)
load2name = {
"smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
"llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
}
base_model = Llama(
config=config,
device=device,
)
config = AutoConfig.from_pretrained(load2name[args.load_path])
base_model.load_state_dict(torch.load(args.load_path))
base_model = Llama(config=config, device=device)
base_model.load_state_dict(torch.load(args.load_path, map_location="cpu"))
model = PipelineParallel(base_model, config).to(device)
del base_model
model.eval()
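
A stand-alone sketch of the loading pattern introduced above: resolve the HF model name from the checkpoint filename via load2name, then load the state dict with map_location="cpu" so the checkpoint stays in host RAM until load_state_dict copies it into the model. The dict entries mirror the diff; load_base_model and build_model are hypothetical stand-ins:

```python
import torch
import torch.nn as nn

# Entries mirror the diff; load_base_model / build_model are hypothetical stand-ins.
load2name = {
    "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
    "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
}

def load_base_model(load_path: str, build_model) -> nn.Module:
    model_name = load2name[load_path]      # also used for AutoConfig / AutoTokenizer
    model = build_model(model_name)
    # map_location="cpu" avoids materializing the checkpoint on the GPU before
    # load_state_dict copies the tensors into the model's own parameters.
    state_dict = torch.load(load_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model
```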
@@ -78,23 +79,23 @@ if __name__ == "__main__":
"What is your favorite color?",
]
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(load2name[args.load_path])
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device=device)
tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
for _ in range(args.max_tokens):
# Create the batch
seq_len = tokenized_prompts["input_ids"].shape[1]
position_index = torch.arange(seq_len).view(1, -1).to(device=device)
position_ids = torch.arange(seq_len).view(1, -1)
batch_prompts = {
"input_ids": tokenized_prompts["input_ids"],
"input_ids": tokenized_prompts["input_ids"].to(device=device),
"target_ids": None,
"position_index": position_index,
"attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool),
"position_ids": position_ids.to(device=device),
"attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device),
"hidden_states": None,
}


@@ -4,7 +4,7 @@ from torch.nn import functional as F
from einops import rearrange
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
def __init__(self, hidden_size, eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
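
The hunk above only drops the default eps; for context, a generic RMSNorm forward (not part of this diff) shows what variance_epsilon is used for:

```python
import torch
import torch.nn as nn

# Generic RMSNorm for reference; the actual forward in the model file is not shown in this hunk.
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x

norm = RMSNorm(hidden_size=16, eps=1e-6)   # eps now has to be passed explicitly
print(norm(torch.randn(2, 4, 16)).shape)   # torch.Size([2, 4, 16])
```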


@@ -3,12 +3,9 @@ from distributed.distributed_primtives import pipeline_communicate, bidirectiona
import torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist
from parallel.base_parallel import BaseParallel
class PipelineParallel(BaseParallel):
class PipelineParallel(nn.Module):
def __init__(self, model, config):
super().__init__(model, config)
#TODO(fmom): find a better model to distributed layers without instantiating a base_model first
super().__init__()
layer_distribution = self.distribute_layers(config.num_hidden_layers)
self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in layer_distribution})
@@ -20,11 +17,11 @@ class PipelineParallel(BaseParallel):
start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank])
return list(range(start_layer, start_layer + layers_per_gpu[pgm.process_group_manager.pp_rank]))
def forward(self, batch, device):
x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device)
def forward(self, input_ids, position_ids, hidden_states):
x = hidden_states if hidden_states is not None else input_ids
x = self.embedding(x)
for layer in self.decoder_layers.values():
x = layer(x, position_ids=batch["position_index"].to(device))
x = layer(x, position_ids=position_ids)
x = self.final_norm(x)
return self.final_proj(x)
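
A toy sketch of the new keyword-only forward contract (ToyStage is a stand-in, not picotron code): the first pipeline stage embeds input_ids, later stages consume the hidden_states received from the previous stage, and position_ids is accepted for signature parity:

```python
import torch
import torch.nn as nn

class ToyStage(nn.Module):
    """Stand-in for one pipeline stage; mirrors the input selection above."""
    def __init__(self, is_first_stage, vocab_size=64, hidden=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden) if is_first_stage else nn.Identity()
        self.layer = nn.Linear(hidden, hidden)

    def forward(self, input_ids, position_ids, hidden_states):
        x = hidden_states if hidden_states is not None else input_ids
        x = self.embedding(x)          # nn.Identity() on non-first stages
        return self.layer(x)

ids = torch.randint(0, 64, (2, 8))
pos = torch.arange(8).expand(2, -1)
h = ToyStage(is_first_stage=True)(input_ids=ids, position_ids=pos, hidden_states=None)
out = ToyStage(is_first_stage=False)(input_ids=ids, position_ids=pos, hidden_states=h)
print(out.shape)  # torch.Size([2, 8, 16])
```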
@@ -41,9 +38,9 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
for _ in range(data_loader.num_local_micro_batches): # All forward passes
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
batch = next(iter(data_loader))
batch["hidden_states"] = input_tensor
output_tensor = model.forward(batch, device)
batch = next(data_loader)
batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
@@ -69,9 +66,9 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
logging_loss, input_tensors, output_tensors = 0.0, [], []
def _forward_step(input_tensor):
batch = next(iter(data_loader))
batch["hidden_states"] = input_tensor
output_tensor = model.forward(batch, device)
batch = next(data_loader)
batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
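
The switch from next(iter(data_loader)) to next(data_loader) in both train-step variants is more than cosmetic: wrapping the loader in iter() on every call creates a fresh iterator, so every "micro-batch" was the first batch again. A short demonstration with a plain DataLoader (MicroBatchDataLoader hides the persistent iterator behind __iter__/__next__, shown in train.py below):

```python
from torch.utils.data import DataLoader

loader = DataLoader(list(range(8)), batch_size=2)

# Old pattern: a fresh iterator per call always yields the first batch.
print([next(iter(loader)) for _ in range(2)])  # [tensor([0, 1]), tensor([0, 1])]

# New pattern: a single persistent iterator actually advances through the data.
it = iter(loader)
print([next(it) for _ in range(2)])            # [tensor([0, 1]), tensor([2, 3])]
```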

train.py

@@ -1,12 +1,13 @@
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
import os
import numpy as np
import torch.nn.functional as F
import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
from datasets import load_dataset,Features, Sequence, Value
import argparse
import distributed.process_group_manager as pgm
@@ -18,57 +19,121 @@ from parallel.data_parallel import DataParallel
from parallel.context_parallel import ContextParallel
from model import Llama
import wandb
import multiprocessing
class MicroBatchDataLoader(DataLoader):
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
self.global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
self.seq_length = seq_length
self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size # each DP rank gets a local batch
self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
self.grad_acc = grad_acc
self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.dataset = load_dataset(dataset_name, split=split)
if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
if num_samples:
self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
dist.barrier()
self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names, num_proc=multiprocessing.cpu_count()).with_format("torch", columns=["input_ids"])
self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
# Tokenize and chunk the dataset
self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)
super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
self.sampler = DistributedSampler(
self.tokenized_dataset,
num_replicas=pgm.process_group_manager.dp_world_size,
rank=pgm.process_group_manager.dp_rank,
shuffle=False
)
super().__init__(
self.tokenized_dataset,
batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size, # in PP we split a single batch into multiple micro-batches
collate_fn=self.collate_batch,
pin_memory=True,
num_workers=num_workers,
sampler=self.sampler,
shuffle=False
)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
def _tokenizer_group_text(texts):
tokenized_text_batch = self.tokenizer.batch_encode_plus(
texts,
return_attention_mask=False,
return_token_type_ids=False,
return_tensors='np'
)
concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
total_length = len(concatenated_tokens['input_ids'])
if total_length >= sequence_length + 1:
total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
result = {
'input_ids': [
concatenated_tokens['input_ids'][i : i + sequence_length + 1]
for i in range(0, total_length - sequence_length, sequence_length)
]
}
return result
def collate_batch(self, batch_data):
batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
batch_size, seq_len = batch_input_ids.shape
tokenized_dataset = dataset.map(
_tokenizer_group_text,
input_columns=text_column_name,
remove_columns=dataset.column_names,
features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)}),
batched=True,
num_proc=num_proc, # Adjust this based on your system capabilities
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {sequence_length+1}",
)
return tokenized_dataset
def collate_batch(self, batch):
batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
batch_size = batch_input_ids.size(0)
start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu
end_idx = start_idx + self.seq_length_per_gpu
input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
position_index = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()
return {
"input_ids": input_ids,
"target_ids": target_ids,
"position_index": position_index,
"position_ids": position_ids,
"attn_mask": attn_mask,
"hidden_states": None
}
def __iter__(self):
if self._iterator is None:
self._iterator = super().__iter__()
return self
def __next__(self):
if self._iterator is None:
self._iterator = super().__iter__()
try:
batch = next(self._iterator)
except StopIteration:
self._iterator = None
raise StopIteration
return batch
def train_step(model, data_loader, device):
total_loss = 0.0
for _ in range(data_loader.num_local_micro_batches):
batch = next(iter(data_loader))
batch = next(data_loader)
input_ids = batch["input_ids"].to(device)
position_ids = batch["position_index"].to(device)
position_ids = batch["position_ids"].to(device)
target_ids = batch["target_ids"].to(device)
batch_size, seq_len = input_ids.shape
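
A stand-alone sketch of the grouping arithmetic in _tokenizer_group_text above: concatenate all token ids, trim so the usable length is a multiple of sequence_length plus one trailing token, then cut windows of sequence_length + 1 (adjacent windows share one token) so collate_batch can shift each window into input_ids / target_ids. The numbers here are toy values:

```python
import numpy as np

sequence_length = 4
token_ids = np.arange(11)          # stand-in for the concatenated corpus tokens

total_length = len(token_ids)
if total_length >= sequence_length + 1:
    total_length = ((total_length - 1) // sequence_length) * sequence_length + 1

chunks = [
    token_ids[i : i + sequence_length + 1]
    for i in range(0, total_length - sequence_length, sequence_length)
]
print(chunks)          # [array([0, 1, 2, 3, 4]), array([4, 5, 6, 7, 8])]
print(chunks[0][:-1])  # input_ids  -> [0 1 2 3]
print(chunks[0][1:])   # target_ids -> [1 2 3 4]
```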
@@ -94,6 +159,7 @@ if __name__ == "__main__":
parser.add_argument("--use_cpu", action="store_true", default=False)
parser.add_argument("--master_addr", type=str, default="localhost")
parser.add_argument("--master_port", type=int, default=29500)
parser.add_argument("--load_path", type=str, default="smollm.pth")
args = parser.parse_args()
@@ -105,7 +171,7 @@ if __name__ == "__main__":
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 4, 1, 3e-4, int(1e4), 1e6, 42
assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@@ -125,9 +191,15 @@ if __name__ == "__main__":
# display_4D_parallelism_grid()
set_all_seed(SEED)
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
load2name = {
"smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
"llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
}
dataset_name = "roneneldan/TinyStories"
config = AutoConfig.from_pretrained(model_name)
config = AutoConfig.from_pretrained(load2name[args.load_path])
if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
wandb.init(
@@ -137,7 +209,7 @@ if __name__ == "__main__":
"tensor_parallel_size": pgm.process_group_manager.tp_size,
"pipeline_parallel_size": pgm.process_group_manager.pp_size,
"data_parallel_size": pgm.process_group_manager.dp_size,
"model": model_name,
"model": load2name[args.load_path],
"dataset": dataset_name,
"max_tokens": MAX_TOKENS,
"learning_rate": LEARNING_RATE,
@@ -147,16 +219,11 @@ if __name__ == "__main__":
},
)
#TODO: find a better way (should need to specify model_name + path to .pth)
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
config = AutoConfig.from_pretrained(model_name)
config = AutoConfig.from_pretrained(load2name[args.load_path])
model = Llama(
config=config,
device=device,
).to(device)
model = Llama(config=config, device=device)
model.load_state_dict(torch.load("smollm.pth"))
# model.load_state_dict(torch.load(args.load_path, map_location="cpu"))
# if pgm.process_group_manager.tp_world_size > 1:
# model = TensorParallel(model, config).to(device)
@@ -172,7 +239,7 @@ if __name__ == "__main__":
model.train()
data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=load2name[args.load_path], num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
@@ -188,7 +255,8 @@ if __name__ == "__main__":
#TODO: add gradient accumulation
while trained_tokens < MAX_TOKENS:
data_loader.set_epoch(step)
#TODO: Add epoch support
# data_loader.set_epoch(step)
optimizer.zero_grad()