Update configs, add results

parent 0bf5e50038
commit 4a6eaa9f27

14  README.md
@@ -14,6 +14,20 @@ We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.

## Full model code and training script

We have released the full GPT model
[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 189
TFLOPs/sec per A100, equivalent to 60.6\% model FLOPs utilization (we don't need
any activation checkpointing).

We also include a training
[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.

## Triton implementation of FlashAttention

Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
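A quick back-of-the-envelope check on the utilization figure quoted in the hunk above (an illustrative aside, not part of the commit; the 312 TFLOPs/sec A100 BF16/FP16 tensor-core peak is an assumption):

```python
# Hypothetical sanity check: 189 TFLOPs/sec against an assumed 312 TFLOPs/sec peak.
achieved_tflops = 189
a100_peak_tflops = 312  # assumed dense BF16/FP16 tensor-core peak of an A100
print(f"model FLOPs utilization ~ {achieved_tflops / a100_peak_tflops:.1%}")  # ~60.6%
```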
BIN  assets/gpt2_training_curve.jpg (new file; binary file not shown; after: 168 KiB)
BIN  assets/gpt2_training_efficiency.jpg (new file; binary file not shown; after: 367 KiB)
BIN  assets/gpt3_training_curve.jpg (new file; binary file not shown; after: 183 KiB)
BIN  assets/gpt3_training_efficiency.jpg (new file; binary file not shown; after: 382 KiB)
@@ -65,14 +65,6 @@ ENV PIP_NO_CACHE_DIR=1
# # apex and pytorch-fast-transformers take a while to compile so we install them first
# TD [2022-04-28] apex is already installed. In case we need a newer commit:
# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
# TD [2021-10-28] pytorch-fast-transformers doesn't have a wheel compatible with CUDA 11.3 and Pytorch 1.10
# So we install from source, and change compiler flag -arch=compute_60 -> -arch=compute_70 for V100
# RUN pip install pytorch-fast-transformers==0.4.0
# RUN pip install git+git://github.com/idiap/fast-transformers.git@v0.4.0 # doesn't work on V100
RUN git clone https://github.com/idiap/fast-transformers \
    && sed -i 's/\["-arch=compute_60"\]/\["-arch=compute_70"\]/' fast-transformers/setup.py \
    && pip install fast-transformers/ \
    && rm -rf fast-transformers

# xgboost conflicts with deepspeed
RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.5
@@ -1,7 +1,12 @@
Examples of how FlashAttention can be integrated into a model (e.g., GPT, ViT)
and trained end-to-end.
We also added optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding).
# Optimized Transformer implementation
This repo contains examples of how FlashAttention can be integrated into a model
(e.g., GPT, ViT) and trained end-to-end. We also provide optimized
implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss,
rotary embedding). Overall this speeds up training by 3-5x compared to the
baseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100,
equivalent to 60.6\% model FLOPs utilization (we don't need any activation
checkpointing). All without changing the model architecture (i.e., no
approximation).

Goals:
- Performance: we optimize for model speed and memory, especially on 1-node
@@ -29,17 +34,36 @@ Non-goals (and other resources):

The GPT model is implemented
[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
And here's an example to construct the GPT3-1.3B model with rotary embedding:
```python
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

seqlen = 2048
hidden_dim = 2048
nheads = 16
n_layer = 24
rotary_emb_fraction = 0.5
config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                    n_layer=n_layer, n_head=nheads,
                    scale_attn_by_inverse_layer_idx=True,
                    rotary_emb_fraction=rotary_emb_fraction,
                    use_flash_attn=True, fused_dense_gelu_dense=True,
                    fused_bias_fc=True, fused_dropout_add_ln=True,
                    pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)
```

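As a quick illustration of the snippet above (not part of the commit), a forward pass might look like the sketch below; it assumes the model follows the usual HF-style interface of returning an output object with `.logits`, and that a CUDA GPU is available, since the FlashAttention kernels run in fp16/bf16 on GPU.

```python
import torch

# Illustrative smoke test for the model constructed above (assumption:
# GPTLMHeadModel returns an output carrying next-token logits, HF-style).
model = model.to(device="cuda", dtype=torch.bfloat16)
input_ids = torch.randint(0, config.vocab_size, (2, seqlen), device="cuda")
with torch.no_grad():
    out = model(input_ids)
print(out.logits.shape)  # expected: (2, seqlen, vocab size padded to a multiple of 8)
```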
We provide the following optimized components:

- FlashAttention: fast and memory-efficient exact attention. This makes
1. FlashAttention: fast and memory-efficient exact attention. This makes
attention much faster and saves a lot of activation memory. As a result we don't need
to use any activation checkpointing.
```sh
pip install flash-attn
```

- Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
make it work for bfloat16. For best performance, you should use CUDA >= 11.8. CuBLAS versions before
@@ -47,16 +71,16 @@ this doesn't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd ../csrc/fused_dense_lib && pip install .
```
- Optimized cross-entropy loss, adapted from Apex's
3. Optimized cross-entropy loss, adapted from Apex's
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
```sh
cd ../csrc/xentropy && pip install .
```
- Fused rotary embedding:
4. Fused rotary embedding:
```sh
cd ../csrc/rotary && pip install .
```
- Fused dropout + residual + LayerNorm, adapted from Apex's
5. Fused dropout + residual + LayerNorm, adapted from Apex's
[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
```sh
@@ -65,8 +89,9 @@ cd ../csrc/layer_norm && pip install .

## Training

Feel free to use the model in your training setup. We also provide here training
scripts to train GPT2 on Openwebtext and GPT3 on The Pile as examples.
We also provide here training scripts to train GPT2 on Openwebtext and GPT3 on
The Pile as examples. Feel free to use the model in your own training setup as
well.

We use [Hydra](https://hydra.cc/) for configuration,
[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
@@ -75,12 +100,20 @@ We use [Hydra](https://hydra.cc/) for configuration,
We use the template from `https://github.com/ashleve/lightning-hydra-template`.
Please read the instructions there to understand the repo structure.

### Requirements

Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)

We provide a Dockerfile that lists all the required packages.

### Dataset preparation

Running the training command would automatically download the datasets
(Openwebtext, Pile), tokenize with the GPT2 tokenizer, concatenate all the
tokens, then save this cache to disk. Alternatively, you can also prepare the
datasets as a separate steps.
datasets as a separate step.

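The "download, tokenize, concatenate, cache" step described above can be pictured with the rough sketch below. This is an illustration using the HuggingFace `datasets` and `transformers` packages only; the repo's datamodule has its own implementation and cache format, and the output filename here is hypothetical.

```python
import numpy as np
from datasets import load_dataset
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
raw = load_dataset("openwebtext", split="train")

def tokenize(batch):
    # Append EOS so documents stay separated once everything is concatenated.
    return {"ids": [tokenizer.encode(t) + [tokenizer.eos_token_id] for t in batch["text"]]}

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"], num_proc=8)
# GPT2's vocab (50257) fits in uint16, so the flat token stream stays compact on disk.
flat = np.concatenate([np.asarray(ids, dtype=np.uint16) for ids in tokenized["ids"]])
np.save("openwebtext_tokens.npy", flat)  # hypothetical cache path
```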
The cached datasets are saved to `${DATA_DIR}/openwebtext` and
`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
@@ -98,36 +131,101 @@ This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```
This takes around 20h on a 96-core CPU. The processed dataset has size 699GB.
This takes around 20h on a 64-core CPU. The processed dataset has size 699GB.

### GPT2 training on Openwebtext
To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8
python run.py experiment=owt/gpt2m-flash trainer.devices=8
python run.py experiment=owt/gpt2l-flash trainer.devices=8
python run.py experiment=owt/gpt2xl-flash trainer.devices=8
python run.py experiment=owt/gpt2s-flash trainer.devices=8  # 125M
python run.py experiment=owt/gpt2m-flash trainer.devices=8  # 355M
python run.py experiment=owt/gpt2l-flash trainer.devices=8  # 760M
python run.py experiment=owt/gpt2xl-flash trainer.devices=8  # 1.6B
```
The default parameters are set for 8 x A100 80GB.

To train with bf16 instead of fp16, add `trainer.precision=bf16`.
To adjust device batch size to fit GPU memory (the global batch size stays the
same, and gradient accumulation is calculated automatically), set `datamodule.batch_size=blah`.

### GPT3 training on The Pile
To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8
python run.py experiment=pile/gpt3m-flash trainer.devices=8
python run.py experiment=pile/gpt3l-flash trainer.devices=8
python run.py experiment=pile/gpt3xl-flash trainer.devices=8
python run.py experiment=pile/gpt3s-flash trainer.devices=8  # 125M
python run.py experiment=pile/gpt3m-flash trainer.devices=8  # 355M
python run.py experiment=pile/gpt3l-flash trainer.devices=8  # 760M
python run.py experiment=pile/gpt3xl-flash trainer.devices=8  # 1.3B
python run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8  # 2.7B
```
The default parameters are set for 8 x A100 80GB.
The default parameters are set for 8 x A100 80GB. We train with bf16 by default.

## Requirements
To train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.

Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
### Training options

We provide a Dockerfile that lists all the required packages.
**Gradient accumulation**: to adjust device batch size to fit into GPU memory
(the global batch size stays the same, and gradient accumulation is calculated
automatically), set `datamodule.batch_size=blah`.

**Multi-node**: to train on multiple nodes, add `trainer.num_nodes=blah`.

**Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.

**Resumable training**: set a name for the run, and then set `resume=True` when
you resume. Training will restart at exactly the same batch.
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True
```

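For the gradient-accumulation option above, the bookkeeping is just the ratio below (a rough sketch, not part of the commit; the global batch size is taken from the GPT3 configs in this commit, the per-device batch size is illustrative):

```python
global_batch_size = 512      # sequences per optimizer step (train.global_batch_size)
devices = 8
per_device_batch_size = 8    # e.g. datamodule.batch_size=8
accumulate_grad_batches = global_batch_size // (devices * per_device_batch_size)
print(accumulate_grad_batches)  # 8 micro-batches accumulated per optimizer step
```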
## Training speed

We measure the wallclock training speed on one node with 8 x A100 80GB SXM4 (400W) and NVLink.

FLOPs are calculated using the formula from the [Megatron-LM
paper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4
to get the model FLOPs (instead of hardware FLOPs with activation
checkpointing).

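A sketch of that estimate (illustrative, not part of the commit): the Megatron-LM formula gives hardware FLOPs per iteration as 96·B·s·l·h²·(1 + s/(6h) + V/(16·l·h)), and scaling by 3/4 drops the recomputation pass. Plugging in GPT3-1.3B's shapes and the throughput from the table below roughly reproduces the 189 TFLOPs/sec figure; the 312 TFLOPs/sec A100 peak used for the utilization number is an assumption.

```python
# Model FLOPs per token for GPT3-1.3B (the batch size B cancels out per token).
s, h, l, V = 2048, 2048, 24, 50257
flops_per_token = 72 * l * h**2 * (1 + s / (6 * h) + V / (16 * l * h))
tokens_per_sec = 169e3  # 8-GPU throughput from the table below
tflops_per_gpu = flops_per_token * tokens_per_sec / 8 / 1e12
print(f"{flops_per_token / 1e9:.1f} GFLOPs/token, {tflops_per_gpu:.0f} TFLOPs/sec per GPU")
# ~8.9 GFLOPs/token and ~188 TFLOPs/sec per GPU, i.e. roughly 60% of a 312 TFLOPs/sec peak.
```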
### GPT2 (sequence length 1024)

![GPT2 training efficiency](https://github.com/HazyResearch/flash-attention/raw/main/assets/gpt2_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-4x faster than the
baseline implementation from Huggingface.

### GPT3 (sequence length 2048)

![GPT3 training efficiency](https://github.com/HazyResearch/flash-attention/raw/main/assets/gpt3_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-5x faster than the
baseline implementation from Huggingface.

For the GPT3-2.7B model, we set head dimension to 128 (instead of 80) for better efficiency.

We include here more details on the training speed with FlashAttention on 8 x
A100 80GB.

| Model     | Batch size (tokens) | Throughput (tokens/sec) | Hours / 1B tokens |
| --------- | ------------------- | ----------------------- | ----------------- |
| GPT3-125M | 0.5M                | 1310k                   | 0.21              |
| GPT3-355M | 0.5M                | 503k                    | 0.55              |
| GPT3-760M | 0.5M                | 245k                    | 1.13              |
| GPT3-1.3B | 1M                  | 169k                    | 1.64              |
| GPT3-2.7B | 1M                  | 85k                     | 3.27              |

As an example, this means that one can train a GPT3-1.3B model on 26B tokens
(compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100.

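Reading the last column against the Chinchilla-style budget mentioned above (illustrative arithmetic only):

```python
hours_per_billion_tokens = 1.64  # GPT3-1.3B row of the table above
print(26 * hours_per_billion_tokens)  # ~42.6 hours, i.e. about 43 hours for 26B tokens
```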
## Training quality

We include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens.
For GPT2, the runs with FlashAttention yield the same loss curve as the runs
with the baseline implementation from Huggingface for 125M and 355M models. For
larger models the baseline implementation just takes too long.

![GPT2 training curve](https://github.com/HazyResearch/flash-attention/raw/main/assets/gpt2_training_curve.jpg)

We include here the loss curve for GPT3 on The Pile, trained for 400B tokens.
The 125M, 355M, 760M models have batch size 512k tokens so this translates to
800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens,
which translates to 400k training steps.

![GPT3 training curve](https://github.com/HazyResearch/flash-attention/raw/main/assets/gpt3_training_curve.jpg)

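The step counts quoted above follow directly from the 400B-token budget (a quick check, not part of the commit):

```python
tokens = 400e9
print(tokens / 512e3)  # ~781k optimizer steps at 0.5M tokens/batch (quoted as ~800k)
print(tokens / 1e6)    # 400k optimizer steps at 1M tokens/batch
```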
@@ -28,7 +28,7 @@ defaults:

datamodule:
  # batch_size: 16
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}

trainer:
  # strategy: null

@@ -4,13 +4,13 @@ defaults:
  - override /model/gpt2model: gpt2-medium

# Can enable mlp_checkpoint_lvl to fit batch_size 32 to A100 40GB
model:
  config:
    mlp_checkpoint_lvl: 1
# model:
#   config:
#     mlp_checkpoint_lvl: 1

datamodule:
  # batch_size: 32
  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}
  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else (32 if ${train.gpu_mem} < 80 else 64))"}

train:
  optimizer:

@@ -10,8 +10,8 @@ defaults:
#     mlp_checkpoint_lvl: 1

datamodule:
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}
  # With adamw-zero optimizer:
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"}
  # With adamw-zero optimizer, on A100 40GB:
  # checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1)
  # checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1)
  # checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1)
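The `${eval:"..."}` expressions in the hunks above pick the per-device batch size from `train.gpu_mem` (presumably the GPU memory in GB). A plain-Python equivalent of the last one, for illustration only:

```python
def per_device_batch_size(gpu_mem_gb: int) -> int:
    # Mirrors: 2 if gpu_mem < 24 else (4 if gpu_mem < 40 else (8 if gpu_mem < 80 else 16))
    if gpu_mem_gb < 24:
        return 2
    if gpu_mem_gb < 40:
        return 4
    return 8 if gpu_mem_gb < 80 else 16

print(per_device_batch_size(80))  # 16 on an 80GB GPU
```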
7  training/configs/experiment/owt/gpt2xl-hf.yaml (new file)
@@ -0,0 +1,7 @@
# @package _global_
defaults:
  - /experiment/owt/gpt2l-hf.yaml
  - override /model/gpt2model: gpt2-xlarge

datamodule:
  batch_size: 1
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-8k.yaml
  - /experiment/pile/gpt3xl-flash-8k.yaml

model:
  config:
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary-8k.yaml
  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml

model:
  config:
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary.yaml
  - /experiment/pile/gpt3xl-flash-rotary.yaml

model:
  config:
@@ -0,0 +1,18 @@
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash.yaml

model:
  config:
    n_embd: 2560
    n_head: 20  # Headdim 128 is faster than headdim 80
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  optimizer:
    lr: 1.6e-4
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary-8k.yaml
  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml

model:
  config:
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary.yaml
  - /experiment/pile/gpt3xl-flash-rotary.yaml

model:
  config:
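The `initializer_range` expression in the new 18-line config above is just the square root of 2 / (5 · n_embd), evaluated here for n_embd = 2560 (an aside, not part of the commit):

```python
n_embd = 2560
print((2 / (5 * n_embd)) ** 0.5)  # 0.0125
```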
18  training/configs/experiment/pile/gpt3-2.7B-flash.yaml (new file)
@@ -0,0 +1,18 @@
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash.yaml

model:
  config:
    n_embd: 2560
    n_head: 32
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  optimizer:
    lr: 1.6e-4
17  training/configs/experiment/pile/gpt3-2.7B-hf-hdim128.yaml (new file)
@@ -0,0 +1,17 @@
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-hf.yaml

model:
  config:
    n_embd: 2560
    n_head: 128
    n_layer: 32

# OOM on A100 80GB even with batch_size = 1
datamodule:
  batch_size: 1

train:
  optimizer:
    lr: 1.6e-4
16  training/configs/experiment/pile/gpt3-2.7B-hf.yaml (new file)
@@ -0,0 +1,16 @@
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-hf.yaml

model:
  config:
    n_embd: 2560
    n_head: 32
    n_layer: 32

datamodule:
  batch_size: 1

train:
  optimizer:
    lr: 1.6e-4
16  training/configs/experiment/pile/gpt3l-hf.yaml (new file)
@@ -0,0 +1,16 @@
# @package _global_
defaults:
  - /experiment/pile/gpt3s-hf.yaml

model:
  config:
    n_embd: 1536
    n_head: 16
    n_layer: 24

datamodule:
  batch_size: 2

train:
  optimizer:
    lr: 2.5e-4
@@ -9,7 +9,7 @@ defaults:
#     mlp_checkpoint_lvl: 1

datamodule:
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}

train:
  optimizer:
11  training/configs/experiment/pile/gpt3m-hf.yaml (new file)
@@ -0,0 +1,11 @@
# @package _global_
defaults:
  - /experiment/pile/gpt3s-hf.yaml
  - override /model/gpt2model: gpt2-medium

datamodule:
  batch_size: 4

train:
  optimizer:
    lr: 3.0e-4
12  training/configs/experiment/pile/gpt3s-hf.yaml (new file)
@@ -0,0 +1,12 @@
# @package _global_
defaults:
  - /experiment/pile/base.yaml
  - override /model: gpt2-hf
  - override /model/gpt2model: gpt2-small

datamodule:
  batch_size: 8

train:
  # Use the standard torch.nn.CrossEntropyLoss
  loss_fn: null
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash.yaml
  - /experiment/pile/gpt3xl-flash.yaml

datamodule:
  max_length: 8192
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary.yaml
  - /experiment/pile/gpt3xl-flash-rotary.yaml

trainer:
  max_steps: 60000
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-8k.yaml
  - /experiment/pile/gpt3xl-flash-8k.yaml

model:
  config:
@@ -1,6 +1,6 @@
# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash.yaml
  - /experiment/pile/gpt3xl-flash.yaml

model:
  config:
@@ -10,7 +10,7 @@ model:
    n_layer: 24

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu} < 80 else 8))"}
  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"}

train:
  global_batch_size: 512
35  training/configs/experiment/pile/gpt3xl-hf.yaml (new file)
@@ -0,0 +1,35 @@
# @package _global_
defaults:
  - /experiment/pile/gpt3s-hf.yaml
  - override /optimizer: adamw-zero

model:
  config:
    n_embd: 2048
    n_head: 16
    n_layer: 24

datamodule:
  batch_size: 2

train:
  global_batch_size: 512
  optimizer:
    lr: 2.0e-4
  scheduler:
    t_initial: 300000

trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}

callbacks:
  model_checkpoint:
    every_n_train_steps: 1000
  model_checkpoint_progress:
    every_n_train_steps: 12500
    fault_tolerant: False  # Saving takes too long
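A rough cross-check of the token budget implied by this new config, assuming the 2048-token sequence length used for the Pile experiments (illustrative only):

```python
max_steps = 400_000
global_batch_size = 512   # sequences per optimizer step
seq_len = 2048            # assumed Pile sequence length
print(max_steps * global_batch_size * seq_len / 1e9)  # ~419 (billion tokens), i.e. ~400B
```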