update wandb_log + set async default to true
commit 069a17237f
parent ca1fcec87f
@@ -134,7 +134,7 @@ class ColumnParallelLinear(torch.nn.Module):
         bias: bool = False,
         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
         gather_output: bool = False,
-        async_all_reduce: bool = False,
+        async_all_reduce: bool = True,
     ) -> None:
         super(ColumnParallelLinear, self).__init__()

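The flag flipped above makes ColumnParallelLinear default to an asynchronous all-reduce. As a rough illustration only (this is not picotron's forward/backward code, and the function below is hypothetical), the pattern such a flag enables is launching the collective with async_op=True and overlapping it with compute that needs no communication:

import torch
import torch.distributed as dist

def _column_parallel_backward_sketch(grad_output, weight, input_):
    # Hypothetical sketch: the input gradient of a column-parallel linear
    # must be summed across the tensor-parallel group.
    grad_input = grad_output @ weight                 # [batch, in_features]
    handle = dist.all_reduce(grad_input, op=dist.ReduceOp.SUM, async_op=True)
    # Overlap: the weight gradient needs no communication, so it can be
    # computed while the all-reduce is still in flight.
    grad_weight = grad_output.t() @ input_            # [out_partition, in_features]
    handle.wait()                                     # result is needed from here on
    return grad_input, grad_weight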
train.py (11 changed lines)
@@ -1,7 +1,7 @@
 """Training script for LLaMA model.
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
-CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
 CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
@@ -147,7 +147,7 @@ if __name__ == "__main__":
     if is_wandb_rank and USE_WANDB:
         wandb.init(
             project="picotron",
-            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
+            name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
             config={
                 "tensor_parallel_size": pgm.process_group_manager.tp_size,
                 "context_parallel_size": pgm.process_group_manager.cp_size,
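The run name now passes tokens_per_step through to_readable_format. That helper's body is not part of this diff; as an assumption, it behaves roughly like the hypothetical sketch below, turning a raw token count into a compact suffixed string:

def to_readable_format_sketch(num: float) -> str:
    # Hypothetical stand-in for picotron's to_readable_format: 1_048_576 -> "1.05M".
    for suffix, scale in (("B", 1e9), ("M", 1e6), ("K", 1e3)):
        if num >= scale:
            return f"{num / scale:.2f}{suffix}"
    return str(int(num))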
@@ -243,7 +243,8 @@ if __name__ == "__main__":

         step_duration = time.time() - step_start_time
         tokens_per_second = tokens_per_step / step_duration
-        mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
+        tokens_per_second_per_gpu = tokens_per_second / world_size
+        mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)

         if is_wandb_rank:
             print(
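get_mfu now receives the new tokens_per_second_per_gpu variable. Its implementation is outside this diff; the sketch below only shows the common way such a helper is defined (the 6 * num_params FLOPs-per-token rule and the peak-FLOPs constant are assumptions, and the real get_mfu takes model_config rather than a hard-coded peak):

def get_mfu_sketch(tokens_per_second_per_gpu: float,
                   num_params: int,
                   peak_flops_per_gpu: float = 312e12) -> float:
    # ~6 FLOPs per parameter per token covers forward + backward (assumption).
    achieved_flops_per_gpu = 6 * num_params * tokens_per_second_per_gpu
    # Model FLOPs Utilization as a percentage of peak (312e12 assumes A100 BF16).
    return 100 * achieved_flops_per_gpu / peak_flops_per_gpu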
@@ -252,7 +253,7 @@ if __name__ == "__main__":
                 f"Loss: {loss:6.4f} | "
                 f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
                 f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
-                f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
+                f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
                 f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''} | "
                 f"MFU: {mfu:5.2f}% | "
                 f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
@@ -261,7 +262,7 @@ if __name__ == "__main__":

             if USE_WANDB:
                 wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\
-                    "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
+                    "mfu": mfu, "tokens_per_second_per_gpu": tokens_per_second_per_gpu, "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})

         if step % CHECKPOINT_FREQ == 0:
             save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")