clean long line of hyperparameters
This commit is contained in:
parent
804f43c97e
commit
bccee5d037
117
train.py
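For context, the change below drops the module-level UPPER_CASE aliases (SEQ_LEN, MICRO_BATCH_SIZE, ...) and reads values straight from the loaded config dict at each use site. A minimal sketch of that access pattern, assuming a JSON config file with the same key nesting as used in train.py (the repository's actual loader may differ):

```python
import json

# Hypothetical loader and file name; only the key nesting below is taken from the diff.
with open("config.json") as f:
    config = json.load(f)

# Before this commit: define an alias once, then use the alias everywhere.
SEQ_LEN = config["training"]["seq_length"]
print(SEQ_LEN)

# After this commit: inline the nested lookup at each use site.
print(config["training"]["seq_length"])
```

The aliases save typing but duplicate state at module level; inlining keeps the loaded config as the single source of truth at the cost of longer expressions.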
@@ -71,44 +71,20 @@ if __name__ == "__main__":

os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
- os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.bfloat16!
+ os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"]
os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"

- dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 # if GPU is not available or not supported, use torch.float32
+ dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"

- # hyperparameters
- #TODO: dont need this many variables
- SEQ_LEN = config["training"]["seq_length"]
- MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
- LEARNING_RATE = config["training"]["learning_rate"]
- NUM_SAMPLES = config["training"]["num_samples"]
- MAX_TOKENS = config["training"]["max_tokens"]
- SEED = config["training"]["seed"]
- TOTAL_TRAIN_STEPS = config["training"]["total_train_steps"]
- GRAD_ACC_STEPS = config["training"]["gradient_accumulation_steps"]
- MODEL_NAME = config["model"]["name"]
- DATASET_NAME = config["dataset"]["name"]
- NUM_WORKERS = config["dataset"]["num_workers"]
- NUM_PROC = config["dataset"]["num_proc"]
- USE_WANDB = config["logging"]["use_wandb"]
- TP_SIZE = config["distributed"]["tp_size"]
- CP_SIZE = config["distributed"]["cp_size"]
- DP_SIZE = config["distributed"]["dp_size"]
- PP_SIZE = config["distributed"]["pp_size"]
- PP_ENGINE = config["distributed"]["pp_engine"]
- LOAD_PATH = config["checkpoint"]["load_path"]
- CHECKPOINT_DIR = config["checkpoint"]["save_dir"]
- CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"]

local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

backend = "gloo" if config["distributed"]["use_cpu"] else "nccl"

- assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
- assert world_size == TP_SIZE * PP_SIZE * DP_SIZE * CP_SIZE, "world_size must be equal to tp_size * pp_size * dp_size * cp_size"
+ assert config["training"]["seq_length"] % config["distributed"]["cp_size"] == 0, "seq_length must be divisible by cp_size for Context Parallelism"
+ assert world_size == config["distributed"]["tp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"] * config["distributed"]["cp_size"], "world_size must be equal to tp_size * pp_size * dp_size * cp_size"

if backend == "nccl":
torch.cuda.set_device(local_rank)
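The two assertions kept at the end of this hunk encode the layout constraints: context parallelism must split the sequence evenly, and the launched world size must factor exactly into the four parallelism dimensions. A small hypothetical helper (not part of train.py) expressing the same checks, with an example layout:

```python
def validate_parallel_layout(world_size: int, tp: int, pp: int, dp: int, cp: int, seq_length: int) -> None:
    # Mirrors the two asserts above: the sequence must split evenly across
    # context-parallel ranks, and every rank must map to exactly one
    # (tp, pp, dp, cp) coordinate.
    if seq_length % cp != 0:
        raise ValueError(f"seq_length={seq_length} must be divisible by cp_size={cp} for Context Parallelism")
    if world_size != tp * pp * dp * cp:
        raise ValueError(f"world_size={world_size} must equal tp*pp*dp*cp={tp * pp * dp * cp}")

# Example: 16 processes arranged as tp=2, pp=2, dp=2, cp=2 with a 4096-token sequence.
validate_parallel_layout(world_size=16, tp=2, pp=2, dp=2, cp=2, seq_length=4096)
```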
@@ -117,32 +93,37 @@ if __name__ == "__main__":
device = torch.device("cpu")

dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
- setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
+ setup_process_group_manager(
+ tp_size=config["distributed"]["tp_size"],
+ cp_size=config["distributed"]["cp_size"],
+ pp_size=config["distributed"]["pp_size"],
+ dp_size=config["distributed"]["dp_size"]
+ )
is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage

- set_all_seed(SEED)
+ set_all_seed(config["training"]["seed"])

start_time = time.time()
data_loader = MicroBatchDataLoader(
- micro_batch_size=MICRO_BATCH_SIZE,
- seq_length=SEQ_LEN,
- dataset_name=DATASET_NAME,
- tokenizer_name=MODEL_NAME,
- grad_acc_steps=GRAD_ACC_STEPS,
- num_workers=NUM_WORKERS,
- num_proc=NUM_PROC,
- num_samples=NUM_SAMPLES
+ micro_batch_size=config["training"]["micro_batch_size"],
+ seq_length=config["training"]["seq_length"],
+ dataset_name=config["dataset"]["name"],
+ tokenizer_name=config["model"]["name"],
+ grad_acc_steps=config["training"]["gradient_accumulation_steps"],
+ num_workers=config["dataset"]["num_workers"],
+ num_proc=config["dataset"]["num_proc"],
+ num_samples=config["training"]["num_samples"]
)

dist.barrier()

print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank)
- tokens_per_step = data_loader.global_batch_size * SEQ_LEN
+ tokens_per_step = data_loader.global_batch_size * config["training"]["seq_length"]

if pgm.process_group_manager.global_rank == 0:
print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank)

- if is_wandb_rank and USE_WANDB:
+ if is_wandb_rank and config["logging"]["use_wandb"]:
wandb.init(
project="picotron",
name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
@@ -153,22 +134,21 @@ if __name__ == "__main__":
"data_parallel_size": pgm.process_group_manager.dp_size,
"model": config["model"]["name"],
"dataset": config["dataset"]["name"],
- "max_tokens": MAX_TOKENS,
- "learning_rate": LEARNING_RATE,
- "seed": SEED,
+ "max_tokens": config["training"]["max_tokens"],
+ "learning_rate": config["training"]["learning_rate"],
+ "seed": config["training"]["seed"],
"micro_batch_size": data_loader.micro_batch_size,
"global_batch_size": data_loader.global_batch_size,
"gradient_accumulation": data_loader.grad_acc_steps,
},
)

- model_config = AutoConfig.from_pretrained(MODEL_NAME)
+ model_config = AutoConfig.from_pretrained(config["model"]["name"])
model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
model_config.num_attention_heads = config["model"]["num_attention_heads"]
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
- model_config.max_position_embeddings = SEQ_LEN
+ model_config.max_position_embeddings = config["training"]["seq_length"]

#TODO: try 70B next
start_time = time.time()

with init_model_with_dematerialized_weights():
@@ -182,13 +162,14 @@ if __name__ == "__main__":

model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])

# TODO: load existing checkpoint here to continue pre-training

if pgm.process_group_manager.cp_world_size > 1:
model = apply_context_parallel(model)

model.to(dtype).to(device)

if pgm.process_group_manager.cp_dp_world_size > 1:
# Context parallel and Data parallel both need gradient synchronization
model = DataParallelBucket(model)

print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
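The comment in this hunk notes that context parallelism and data parallelism both need gradient synchronization, which is why the model is wrapped in DataParallelBucket whenever cp_dp_world_size > 1. As a rough illustration of what such a wrapper has to do (not the repository's implementation, which buckets gradients and overlaps communication with backward), a hypothetical per-parameter hook that averages gradients over an already-initialized process group might look like:

```python
import torch
import torch.distributed as dist

def add_grad_sync_hooks(model: torch.nn.Module, group=None) -> None:
    """Average each parameter's gradient over `group` during backward.

    Sketch only: a production wrapper groups parameters into buckets and
    overlaps the all-reduce with the backward pass instead of issuing one
    collective per parameter.
    """
    world = dist.get_world_size(group=group)

    def hook(grad: torch.Tensor) -> torch.Tensor:
        grad = grad / world  # pre-scale so the SUM below yields the mean
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
        return grad

    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(hook)
```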
@@ -203,39 +184,33 @@ if __name__ == "__main__":
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()

- optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
+ optimizer = AdamW(model.parameters(), lr=config["training"]["learning_rate"], **extra_args)

checkpoint_manager = CheckpointManager()

trained_tokens, step = 0, 0
- if LOAD_PATH:
- step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH)
+ if config["checkpoint"]["load_path"]:
+ step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, config["checkpoint"]["load_path"])

dist.barrier()
#TODO: Add activation checkpointing

def _all_reduce_loss_across_dp_cp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
# only the last stage of the pipeline parallelism contains the loss
# we need to average the loss among the data/context parallel group
if pgm.process_group_manager.pp_is_last_stage:
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
return reduced_loss.item()

#TODO: try/except for better error handling
- while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
- #TODO: Add epoch support
- # data_loader.set_epoch(step)
+ while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]:
step_start_time = time.time()
optimizer.zero_grad()

if pgm.process_group_manager.pp_world_size > 1:
- if PP_ENGINE == "afab":
+ if config["distributed"]["pp_engine"] == "afab":
loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
- elif PP_ENGINE == "1f1b":
+ elif config["distributed"]["pp_engine"] == "1f1b":
loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype)
else:
- raise ValueError(f"Invalid pipeline parallel engine: {PP_ENGINE}")
+ raise ValueError(f"Invalid pipeline parallel engine: {config['distributed']['pp_engine']}")
else:
loss = train_step(model, data_loader, device)

@@ -245,7 +220,6 @@ if __name__ == "__main__":
trained_tokens += tokens_per_step
step += 1

# In DDP implementation I need to reset the gradient buffers
if hasattr(model, 'reset'):
model.reset()

@@ -256,21 +230,26 @@ if __name__ == "__main__":
f"Global batch size: {to_readable_format(tokens_per_step)}, "
f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, "
f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, "
- f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''}, "
+ f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''}, "
f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB"
, is_print_rank=is_wandb_rank)

- if USE_WANDB:
- wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\
- "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
+ if config["logging"]["use_wandb"]:
+ wandb.log({
+ "loss": loss,
+ "tokens_per_step": tokens_per_step,
+ "tokens_per_second": tokens_per_step / step_duration,
+ "memory_usage": torch.cuda.memory_reserved() / 1e9,
+ "trained_tokens": trained_tokens
+ })

- if step % CHECKPOINT_FREQ == 0:
- checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
+ if step % config["checkpoint"]["save_frequency"] == 0:
+ checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, config["checkpoint"]["save_dir"]+f"/{step}")

- if step >= TOTAL_TRAIN_STEPS:
+ if step >= config["training"]["total_train_steps"]:
break

- if is_wandb_rank and USE_WANDB:
+ if is_wandb_rank and config["logging"]["use_wandb"]:
wandb.finish()

dist.destroy_process_group()
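One trade-off of this cleanup is that nested lookups such as config["checkpoint"]["save_frequency"] get repeated verbatim. If the long chains ever become a nuisance, a tiny hypothetical accessor (not used in this commit) keeps the config dict as the single source of truth without reintroducing module-level constants:

```python
from functools import reduce

def cfg(config: dict, path: str, default=None):
    """Hypothetical dotted-path accessor, e.g. cfg(config, "training.seq_length")."""
    try:
        return reduce(lambda d, k: d[k], path.split("."), config)
    except (KeyError, TypeError):
        return default

config = {"training": {"seq_length": 4096}}
assert cfg(config, "training.seq_length") == 4096
```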