flash-attention/training/configs/experiment/pile/base.yaml

# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: thepile
  - override /optimizer: adamw-apex # slight speedup (1-2%) over Pytorch AdamW
  - override /scheduler: cosine-warmup-timm
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
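  # div_up and eval are custom OmegaConf resolvers (registered elsewhere in this repo's
  # training code). With the values in this file (global_batch_size=256, devices=8,
  # batch_size=16, num_nodes=1) this resolves to div_up(256, 8 * 16 * 1) = 2 accumulation steps.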
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 800000
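  # Validate every 2000 optimizer steps: Lightning counts micro-batches here, so we scale
  # by accumulate_grad_batches to keep the interval fixed in optimizer steps.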
  val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}}
  check_val_every_n_epoch: null # We don't care about epoch boundaries
  precision: bf16
  gradient_clip_val: 1.0
  strategy: null
datamodule:
  batch_size: 16 # Per GPU
  batch_size_eval: ${.batch_size} # Fused dense only supports batch sizes of at most 64k
  max_length: 2048
  fault_tolerant: True
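  # Resolves to True here since trainer.devices = 8.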
  ddp: ${eval:"${trainer.devices} > 1"}
train:
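  # Shells out to nvidia-smi for GPU 0's total memory (in MiB) and rounds it to an
  # approximate GB figure, presumably so other settings can be conditioned on GPU size.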
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
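  # 256 sequences of max_length 2048 is roughly 0.5M tokens per optimizer step.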
  global_batch_size: 256
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    t_in_epochs: False
    t_initial: 600000
    warmup_lr_init: 1e-6
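    # warmup_t resolves to 0.01 * 800000 = 8000 warmup steps;
    # lr_min resolves to 0.1 * 6e-4 = 6e-5.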
    warmup_t: ${eval:0.01 * ${trainer.max_steps}}
    lr_min: ${eval:0.1 * ${train.optimizer.lr}}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable when using DeepSpeed with 16-bit precision.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True # to save memory
eval:
  log_on_step: True # One training epoch takes too long; we want to see metrics per training step
callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
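    # oc.select falls back to '' if no run name is set, so checkpoints then land
    # directly under ${work_dir}/checkpoints/.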
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: False
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    # fault_tolerant: True # The .pl_auto_save.ckpt doesn't get saved by all workers
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1 # Save all the checkpoints
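    # ${..model_checkpoint.dirpath} is a relative interpolation: it goes up one level to
    # `callbacks` and reuses the sibling model_checkpoint's dirpath, so both callbacks
    # write to the same directory.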
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null