add wandb support

commit 6f6bc1945a
parent cfbf6c170e
@@ -3,4 +3,5 @@ triton==2.1.0
 numpy==1.26.4
 datasets==2.19.1
 transformers==4.41.1
 debugpy-run
+wandb
train.py
@@ -13,6 +13,8 @@ from process_group_manager import setup_process_group_manager
 from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from data_parallel import DataParallel
 from dataset import MicroBatchDataLoader
+import wandb
+
 
 def train_step(model, data_loader, device):
     total_loss = 0.0
@@ -43,6 +45,7 @@ if __name__ == "__main__":
     parser.add_argument("--tp_size", type=int, default=1)
     parser.add_argument("--pp_size", type=int, default=1)
     parser.add_argument("--dp_size", type=int, default=1)
+    parser.add_argument("--use_wandb", action="store_true", default=False)
 
     args = parser.parse_args()
 
@@ -50,7 +53,7 @@ if __name__ == "__main__":
     local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
     host, port = os.environ["MASTER_ADDR"], int(os.environ["MASTER_PORT"])
 
-    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800
+    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
 
     dist.init_process_group(rank=local_rank, world_size=world_size, backend="nccl", init_method=f"tcp://{host}:{port}")
     torch.cuda.set_device(local_rank)
@@ -60,9 +63,29 @@ if __name__ == "__main__":
     if pgm.process_group_manager.global_rank == local_rank:
         display_parallelism_grid()
 
-    set_all_seed(seed=42)
+    set_all_seed(SEED)
     model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
+    dataset_name = "roneneldan/TinyStories"
     config = AutoConfig.from_pretrained(model_name)
+
+    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
+        wandb.init(
+            project="picotron",
+            name=f"test_convergence_{pgm.process_group_manager}",
+            config={
+                "data_parallel_size": pgm.process_group_manager.dp_size,
+                "tensor_parallel_size": pgm.process_group_manager.tp_size,
+                "pipeline_parallel_size": pgm.process_group_manager.pp_size,
+                "model": model_name,
+                "dataset": dataset_name,
+                "max_tokens": MAX_TOKENS,
+                "learning_rate": LEARNING_RATE,
+                "seed": SEED,
+                "micro_batch_size": MICRO_BATCH_SIZE,
+                "global_batch_size": GLOBAL_BATCH_SIZE,
+            },
+        )
+
     model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device)
 
     if pgm.process_group_manager.pp_world_size > 1:
@@ -73,7 +96,7 @@ if __name__ == "__main__":
 
     model.train()
 
-    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, "roneneldan/TinyStories", model_name, num_samples=NUM_SAMPLES)
+    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
     tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, config.hidden_size)
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
 
@@ -82,7 +105,6 @@ if __name__ == "__main__":
 
     dist.barrier()
 
     #TODO: find a way to setup reference model training
     #TODO: Add Context Parallelism
     #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
     #TODO: Add activation checkpointing
@@ -108,5 +130,11 @@ if __name__ == "__main__":
 
         if pgm.process_group_manager.global_rank == 0:
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
 
+        if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
+            wandb.log({"loss": loss, "trained_tokens": trained_tokens})
+
+    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
+        wandb.finish()
+
     dist.destroy_process_group()
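
Note: with this change, Weights & Biases logging stays opt-in. Nothing is logged unless --use_wandb is passed, and only global rank 0 calls wandb.init, wandb.log, and wandb.finish, so a single run appears per job. A minimal launch sketch, assuming torchrun is used (so LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are set by the launcher) and that wandb is installed and authenticated (e.g. via wandb login); the GPU count and parallelism sizes below are illustrative, not part of this commit:

    torchrun --nproc_per_node=2 train.py --tp_size 1 --pp_size 1 --dp_size 2 --use_wandb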