add some logs, refactor dataloader

This commit is contained in:
zzhhjjj 2024-10-23 00:38:27 +00:00
parent ec1e1e5ccf
commit 63307c79a1
2 changed files with 61 additions and 29 deletions

View File

@ -1,6 +1,6 @@
"""Training script for LLaMA model.
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --tp_size 2
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 4
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 1 --dp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --pp_size 2
@ -9,6 +9,8 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --
"""
import os
import time
import argparse
import numpy as np
import torch.nn.functional as F
import torch, torch.distributed as dist
@ -17,12 +19,12 @@ from transformers import AutoConfig
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset,Features, Sequence, Value
import argparse
from functools import partial
from datasets import Features, Sequence, Value
import numpy as np
from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
import src.distributed.process_group_manager as pgm
from utils import set_all_seed, print
from utils import set_all_seed, print, to_readable_format
from src.distributed.process_group_manager import setup_process_group_manager
from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from src.parallel.data_parallel.data_parallel_bucket import DataParallel
@ -70,33 +72,45 @@ class MicroBatchDataLoader(DataLoader):
shuffle=False
)
@staticmethod
def tokenizer_group_text(examples, tokenizer, sequence_length):
"""Tokenize a list of texts and group them in chunks of sequence_length + 1"""
tokenized_text_batch = tokenizer.batch_encode_plus(
examples,
return_attention_mask=False,
return_token_type_ids=False,
return_tensors='np'
)
concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
total_length = len(concatenated_tokens['input_ids'])
if total_length >= sequence_length + 1:
total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
result = {
'input_ids': [
concatenated_tokens['input_ids'][i : i + sequence_length + 1]
for i in range(0, total_length - sequence_length, sequence_length)
]
}
return result
def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
def _tokenizer_group_text(texts):
tokenized_text_batch = self.tokenizer.batch_encode_plus(
texts,
return_attention_mask=False,
return_token_type_ids=False,
return_tensors='np'
)
concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
total_length = len(concatenated_tokens['input_ids'])
if total_length >= sequence_length + 1:
total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
result = {
'input_ids': [
concatenated_tokens['input_ids'][i : i + sequence_length + 1]
for i in range(0, total_length - sequence_length, sequence_length)
]
}
return result
"""Tokenize the dataset and group texts in chunks of sequence_length + 1"""
# Create a partial function with fixed arguments
tokenizer_func = partial(
self.tokenizer_group_text,
tokenizer=self.tokenizer,
sequence_length=sequence_length
)
tokenized_dataset = dataset.map(
_tokenizer_group_text,
tokenizer_func,
input_columns=text_column_name,
remove_columns=dataset.column_names,
features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)}),
features=Features({
"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)
}),
batched=True,
num_proc=num_proc, # Adjust this based on your system capabilities
num_proc=num_proc,
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {sequence_length+1}",
)
@ -189,7 +203,7 @@ if __name__ == "__main__":
# SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
## hyperparameters
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 32, 1, 3e-4, 100000, int(10e8), 42
grad_acc = 16
assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@ -213,7 +227,9 @@ if __name__ == "__main__":
dataset_name = "roneneldan/TinyStories"
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
# model_name = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_name)
config.num_hidden_layers = 16
config.num_attention_heads = 16
config.num_key_value_heads = 4
@ -271,7 +287,7 @@ if __name__ == "__main__":
while trained_tokens < MAX_TOKENS:
#TODO: Add epoch support
# data_loader.set_epoch(step)
step_start_time = time.time()
optimizer.zero_grad()
if pgm.process_group_manager.pp_world_size > 1:
@ -288,12 +304,16 @@ if __name__ == "__main__":
# In DDP implementation I need to reset the gradient buffers
if hasattr(model, 'reset'):
model.reset()
step_duration = time.time() - step_start_time
if pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage:
print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
f"Global batch size: {tokens_per_step}, "
f"Tokens: {trained_tokens}/{MAX_TOKENS}"
)
f"Global batch size: {to_readable_format(tokens_per_step)}, "
f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, "
f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, "
f"Tokens: {to_readable_format(trained_tokens)}/{to_readable_format(MAX_TOKENS)}"
)
if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
wandb.log({"loss": loss, "trained_tokens": trained_tokens})

View File

@ -19,6 +19,18 @@ def set_all_seed(seed):
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
def to_readable_format(num, precision=2):
if num >= 1e12:
return f"{num / 1e12:.{precision}f}T"
elif num >= 1e9:
return f"{num / 1e9:.{precision}f}B"
elif num >= 1e6:
return f"{num / 1e6:.{precision}f}M"
elif num >= 1e3:
return f"{num / 1e3:.{precision}f}K"
else:
return f"{num:.{precision}f}"
## def display_4D_parallelism_grid():
# #TODO(fmom): fix me
# #TODO(fmom): add color to distinguish between different parallelism groups