Various fixes (modeling, dataloader, CPU load)

ferdinand.mom 2024-10-18 14:33:46 +00:00
parent 81726dfffe
commit 0b1d02a402
5 changed files with 132 additions and 73 deletions


@@ -116,11 +116,7 @@ if __name__ == "__main__":
     model_hf = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
-    model = Llama(
-        config=model_hf.config,
-        device=device,
-    )
+    model = Llama(config=model_hf.config, device=device)
     picotron_to_hf = get_weights_mapping(model_hf, to_hf=True)
     ref_state_dict = model_hf.state_dict()
@@ -137,10 +133,7 @@ if __name__ == "__main__":
     torch.save(model.state_dict(), args.save_path)
-    new_model = Llama(
-        config=model_hf.config,
-        device=device,
-    )
+    new_model = Llama(config=model_hf.config, device=device)
     new_model.load_state_dict(torch.load(args.save_path))
     print("Sanity check weight ...")


@@ -13,7 +13,7 @@ from model import Llama
 def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
     if pgm.process_group_manager.pp_world_size == 1:
-        return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_index"])
+        return model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])

     batch_size = batch["input_ids"].shape[0]
     seq_len = batch["input_ids"].shape[1]
@@ -28,7 +28,7 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor:
     batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer
-    output_tensor = model.forward(batch, device)
+    output_tensor = model.forward(input_ids=batch["input_ids"], position_ids=batch["position_ids"], hidden_states=batch["hidden_states"])

     # Send output to the next stage.
     pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device)
@@ -57,17 +57,18 @@ if __name__ == "__main__":
     setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1, cp_size=1)
     set_all_seed(seed=42)

-    #TODO: find a better way (should need to specify model_name + path to .pth)
-    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
-    config = AutoConfig.from_pretrained(model_name)
+    load2name = {
+        "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
+        "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
+    }
+    config = AutoConfig.from_pretrained(load2name[args.load_path])

-    base_model = Llama(
-        config=config,
-        device=device,
-    )
-    base_model.load_state_dict(torch.load(args.load_path))
+    base_model = Llama(config=config, device=device)
+    base_model.load_state_dict(torch.load(args.load_path, map_location="cpu"))

     model = PipelineParallel(base_model, config).to(device)
     del base_model
     model.eval()
@ -78,23 +79,23 @@ if __name__ == "__main__":
"What is your favorite color?", "What is your favorite color?",
] ]
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(load2name[args.load_path])
tokenizer.padding_side = "left" tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device=device) tokenized_prompts = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
for _ in range(args.max_tokens): for _ in range(args.max_tokens):
# Create the batch # Create the batch
seq_len = tokenized_prompts["input_ids"].shape[1] seq_len = tokenized_prompts["input_ids"].shape[1]
position_index = torch.arange(seq_len).view(1, -1).to(device=device) position_ids = torch.arange(seq_len).view(1, -1)
batch_prompts = { batch_prompts = {
"input_ids": tokenized_prompts["input_ids"], "input_ids": tokenized_prompts["input_ids"].to(device=device),
"target_ids": None, "target_ids": None,
"position_index": position_index, "position_ids": position_ids.to(device=device),
"attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool), "attn_mask": tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device),
"hidden_states": None, "hidden_states": None,
} }


@@ -4,7 +4,7 @@ from torch.nn import functional as F
 from einops import rearrange

 class RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps


@@ -3,12 +3,9 @@ from distributed.distributed_primtives import pipeline_communicate, bidirectiona
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
-from parallel.base_parallel import BaseParallel

-class PipelineParallel(BaseParallel):
+class PipelineParallel(nn.Module):
     def __init__(self, model, config):
-        super().__init__(model, config)
-        #TODO(fmom): find a better model to distributed layers without instantiating a base_model first
+        super().__init__()
         layer_distribution = self.distribute_layers(config.num_hidden_layers)
         self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
         self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in layer_distribution})
@@ -20,11 +17,11 @@ class PipelineParallel(BaseParallel):
         start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank])
         return list(range(start_layer, start_layer + layers_per_gpu[pgm.process_group_manager.pp_rank]))

-    def forward(self, batch, device):
-        x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device)
+    def forward(self, input_ids, position_ids, hidden_states):
+        x = hidden_states if hidden_states is not None else input_ids
         x = self.embedding(x)
         for layer in self.decoder_layers.values():
-            x = layer(x, position_ids=batch["position_index"].to(device))
+            x = layer(x, position_ids=position_ids)
         x = self.final_norm(x)
         return self.final_proj(x)
@@ -41,9 +38,9 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
     for _ in range(data_loader.num_local_micro_batches): # All forward passes
         input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
-        batch = next(iter(data_loader))
-        batch["hidden_states"] = input_tensor
-        output_tensor = model.forward(batch, device)
+        batch = next(data_loader)
+        batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
+        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
         pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)

         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
@@ -69,9 +66,9 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
     logging_loss, input_tensors, output_tensors = 0.0, [], []

     def _forward_step(input_tensor):
-        batch = next(iter(data_loader))
-        batch["hidden_states"] = input_tensor
-        output_tensor = model.forward(batch, device)
+        batch = next(data_loader)
+        batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
+        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])

         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
         if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
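
Note on the forward() change above: every pipeline stage now exposes the same keyword signature, so callers no longer hand the whole batch plus a device to the model. On the first stage hidden_states is None and input_ids goes through the embedding; on later stages the embedding is nn.Identity() and the activations received from the previous rank are passed in as hidden_states. A minimal illustrative sketch of that calling convention (the two-stage setup and names below are hypothetical, not code from this commit):

    def run_two_stages(stage0, stage1, input_ids, position_ids):
        # First stage: no incoming activations, so forward() embeds input_ids.
        acts = stage0(input_ids=input_ids, position_ids=position_ids, hidden_states=None)
        # In the real code `acts` travels to the next rank via pipeline_communicate
        # (send_forward / recv_forward); here the hand-off is a plain Python variable.
        # Last stage: embedding is nn.Identity(), so the received activations flow
        # straight into its decoder layers and input_ids is ignored.
        logits = stage1(input_ids=input_ids, position_ids=position_ids, hidden_states=acts)
        return logits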

train.py

@@ -1,12 +1,13 @@
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 import os
+import numpy as np
 import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
 from transformers import AutoTokenizer
 from torch.utils.data import DataLoader, DistributedSampler
-from datasets import load_dataset
+from datasets import load_dataset, Features, Sequence, Value
 import argparse

 import distributed.process_group_manager as pgm
@@ -18,57 +19,121 @@ from parallel.data_parallel import DataParallel
 from parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
-import multiprocessing

 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
-        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
-        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size
+    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
+        self.global_batch_size = global_batch_size
+        self.micro_batch_size = micro_batch_size
+        self.seq_length = seq_length
+        self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size # each DP rank gets a local batch
         self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
         self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
+        self.grad_acc = grad_acc
         self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size

         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         self.dataset = load_dataset(dataset_name, split=split)
-        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+        if num_samples:
+            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))

         dist.barrier()
-        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names, num_proc=multiprocessing.cpu_count()).with_format("torch", columns=["input_ids"])
-        self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False)
-        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
+        # Tokenize and chunk the dataset
+        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)
+
+        self.sampler = DistributedSampler(
+            self.tokenized_dataset,
+            num_replicas=pgm.process_group_manager.dp_world_size,
+            rank=pgm.process_group_manager.dp_rank,
+            shuffle=False
+        )
+
+        super().__init__(
+            self.tokenized_dataset,
+            batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size, # in PP we split a single batch into multiple micro-batches
+            collate_fn=self.collate_batch,
+            pin_memory=True,
+            num_workers=num_workers,
+            sampler=self.sampler,
+            shuffle=False
+        )

-    def set_epoch(self, epoch):
-        self.sampler.set_epoch(epoch)
+    def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
+        def _tokenizer_group_text(texts):
+            tokenized_text_batch = self.tokenizer.batch_encode_plus(
+                texts,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                return_tensors='np'
+            )
+            concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
+            total_length = len(concatenated_tokens['input_ids'])
+            if total_length >= sequence_length + 1:
+                total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
+            result = {
+                'input_ids': [
+                    concatenated_tokens['input_ids'][i : i + sequence_length + 1]
+                    for i in range(0, total_length - sequence_length, sequence_length)
+                ]
+            }
+            return result

-    def collate_batch(self, batch_data):
-        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
-        batch_size, seq_len = batch_input_ids.shape
+        tokenized_dataset = dataset.map(
+            _tokenizer_group_text,
+            input_columns=text_column_name,
+            remove_columns=dataset.column_names,
+            features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)}),
+            batched=True,
+            num_proc=num_proc, # Adjust this based on your system capabilities
+            load_from_cache_file=True,
+            desc=f"Grouping texts in chunks of {sequence_length+1}",
+        )
+        return tokenized_dataset
+
+    def collate_batch(self, batch):
+        batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
+        batch_size = batch_input_ids.size(0)
         start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu
         end_idx = start_idx + self.seq_length_per_gpu
         input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
         target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
-        position_index = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
+        position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
         local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
         attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

         return {
             "input_ids": input_ids,
             "target_ids": target_ids,
-            "position_index": position_index,
+            "position_ids": position_ids,
             "attn_mask": attn_mask,
             "hidden_states": None
         }

+    def __iter__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        return self
+
+    def __next__(self):
+        if self._iterator is None:
+            self._iterator = super().__iter__()
+        try:
+            batch = next(self._iterator)
+        except StopIteration:
+            self._iterator = None
+            raise StopIteration
+        return batch
+
 def train_step(model, data_loader, device):
     total_loss = 0.0

     for _ in range(data_loader.num_local_micro_batches):
-        batch = next(iter(data_loader))
+        batch = next(data_loader)
         input_ids = batch["input_ids"].to(device)
-        position_ids = batch["position_index"].to(device)
+        position_ids = batch["position_ids"].to(device)
         target_ids = batch["target_ids"].to(device)
         batch_size, seq_len = input_ids.shape
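
The _tokenizer_group_text helper introduced above replaces the old per-example padding/truncation with concatenate-then-chunk packing: all token ids in a map batch are concatenated, the tail that cannot fill a whole window is dropped, and the stream is sliced into windows of sequence_length + 1 tokens with stride sequence_length, so neighbouring windows overlap by one token and each window yields sequence_length inputs plus sequence_length shifted targets. A small standalone sketch of just that arithmetic with made-up numbers (not code from the repository):

    import numpy as np

    def chunk_token_stream(token_ids, sequence_length):
        # Same arithmetic as _tokenizer_group_text: keep ((T - 1) // L) * L + 1 tokens,
        # then cut windows of L + 1 tokens with stride L.
        total_length = len(token_ids)
        if total_length >= sequence_length + 1:
            total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
        return [
            token_ids[i : i + sequence_length + 1]
            for i in range(0, total_length - sequence_length, sequence_length)
        ]

    stream = np.arange(10)  # ten dummy token ids: 0..9
    chunks = chunk_token_stream(stream, sequence_length=4)
    # total_length becomes ((10 - 1) // 4) * 4 + 1 = 9, giving two windows of 5 tokens:
    # [0 1 2 3 4] and [4 5 6 7 8]; token 9 is dropped as an incomplete tail.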
@@ -94,6 +159,7 @@ if __name__ == "__main__":
     parser.add_argument("--use_cpu", action="store_true", default=False)
     parser.add_argument("--master_addr", type=str, default="localhost")
     parser.add_argument("--master_port", type=int, default=29500)
+    parser.add_argument("--load_path", type=str, default="smollm.pth")

     args = parser.parse_args()
@@ -105,7 +171,7 @@ if __name__ == "__main__":
     host = os.environ["MASTER_ADDR"]
     port = int(os.environ["MASTER_PORT"])

-    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
+    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 4, 1, 3e-4, int(1e4), 1e6, 42

     assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@@ -125,9 +191,15 @@ if __name__ == "__main__":
     # display_4D_parallelism_grid()
     set_all_seed(SEED)

-    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
+    load2name = {
+        "smollm.pth": "HuggingFaceTB/SmolLM-360M-Instruct",
+        "llama1b.pth": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        "llama3-B.pth": "meta-llama/Meta-Llama-3-8B",
+    }
+
     dataset_name = "roneneldan/TinyStories"
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(load2name[args.load_path])

     if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
         wandb.init(
@@ -137,7 +209,7 @@ if __name__ == "__main__":
             "tensor_parallel_size": pgm.process_group_manager.tp_size,
             "pipeline_parallel_size": pgm.process_group_manager.pp_size,
             "data_parallel_size": pgm.process_group_manager.dp_size,
-            "model": model_name,
+            "model": load2name[args.load_path],
             "dataset": dataset_name,
             "max_tokens": MAX_TOKENS,
             "learning_rate": LEARNING_RATE,
@@ -147,16 +219,11 @@ if __name__ == "__main__":
         },
     )

-    #TODO: find a better way (should need to specify model_name + path to .pth)
-    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(load2name[args.load_path])

-    model = Llama(
-        config=config,
-        device=device,
-    ).to(device)
+    model = Llama(config=config, device=device)

-    model.load_state_dict(torch.load("smollm.pth"))
+    # model.load_state_dict(torch.load(args.load_path, map_location="cpu"))

     # if pgm.process_group_manager.tp_world_size > 1:
     #     model = TensorParallel(model, config).to(device)
@@ -172,7 +239,7 @@ if __name__ == "__main__":
     model.train()

-    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
+    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=load2name[args.load_path], num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)

     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
@@ -188,7 +255,8 @@ if __name__ == "__main__":
     #TODO: add gradient accumulation
     while trained_tokens < MAX_TOKENS:
-        data_loader.set_epoch(step)
+        #TODO: Add epoch support
+        # data_loader.set_epoch(step)
         optimizer.zero_grad()
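
The __iter__/__next__ pair added to MicroBatchDataLoader is what makes the plain next(data_loader) calls in train_step and in the pipeline schedules work: the loader caches a single iterator in self._iterator (an attribute DataLoader already initializes to None), whereas the old pattern next(iter(data_loader)) built a fresh iterator on every call and, with a non-shuffling sampler, kept returning the first micro-batch. A minimal standalone sketch of the same caching idea (illustrative class and data, not the repository's loader):

    from torch.utils.data import DataLoader

    class ResumableLoader(DataLoader):
        # A DataLoader that can be consumed with plain next() across calls.
        def __iter__(self):
            if self._iterator is None:       # reuse the cached iterator if one exists
                self._iterator = super().__iter__()
            return self

        def __next__(self):
            if self._iterator is None:
                self._iterator = super().__iter__()
            try:
                return next(self._iterator)
            except StopIteration:
                self._iterator = None        # allow a fresh pass on the next call
                raise

    loader = ResumableLoader(list(range(8)), batch_size=2)
    print(next(loader))  # tensor([0, 1])
    print(next(loader))  # tensor([2, 3]) -- advances instead of restarting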