diff --git a/train.py b/train.py index 826e821..e1d10bb 100644 --- a/train.py +++ b/train.py @@ -262,7 +262,7 @@ if __name__ == "__main__": "learning_rate": LEARNING_RATE, "seed": SEED, "micro_batch_size": MICRO_BATCH_SIZE, - "global_batch_size": LOCAL_BATCH_SIZE * args.dp_size, + "global_batch_size": LOCAL_BATCH_SIZE * args.dp_size * grad_acc, }, )