# Optimized Transformer implementation
This repo contains examples of how FlashAttention can be integrated into a model
(e.g., GPT, ViT) and trained end-to-end. We also provide optimized
implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss,
rotary embedding). Overall this speeds up training by 3-5x compared to the
baseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100,
equivalent to 60.6% model FLOPs utilization (we don't need any activation
checkpointing). All of this comes without changing the model architecture (i.e., no
approximation).

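
As a quick sanity check on the utilization figure above, here is a minimal sketch of the arithmetic. It assumes the A100's dense BF16/FP16 Tensor Core peak of 312 TFLOPs/sec as the denominator; the function name is ours and is not part of this repo.

```python
# Assumed peak: A100 dense BF16/FP16 Tensor Core throughput, in TFLOPs/sec.
A100_PEAK_TFLOPS = 312.0


def model_flops_utilization(achieved_tflops, peak_tflops=A100_PEAK_TFLOPS):
    """Fraction of the hardware peak spent on model FLOPs."""
    return achieved_tflops / peak_tflops


mfu = model_flops_utilization(189.0)
print(f"{mfu:.1%}")  # prints 60.6%
```
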

Goals:
- Performance: we optimize for model speed and memory, especially on a single node
  (e.g., with 8 A100s).
- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),
  and the model code illustrates how these components can be put together.
  The training code also aims to be model- & task-agnostic.

Non-goals (and other resources):
- Support as many models as possible: Huggingface's
  [transformers](https://github.com/huggingface/transformers) and
  [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
  training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
  training techniques (e.g., pipeline parallelism, tensor parallelism),
  check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
  [DeepSpeed](https://github.com/microsoft/deepspeed).
- Inference: we currently focus on training (this might change in the future).
  If you want fast inference, take a look at
  [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
- Production: this codebase was written during several research projects to validate ideas
  on speeding up ML models.

## Model Components

The GPT model is implemented
[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
Here is an example that constructs the GPT3-1.3B model with rotary embedding:
```python
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

# GPT3-1.3B dimensions
seqlen = 2048
hidden_dim = 2048
nheads = 16
n_layer = 24
rotary_emb_fraction = 0.5
# The flags on the last three lines turn on the optimized components listed below.
config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                    n_layer=n_layer, n_head=nheads,
                    scale_attn_by_inverse_layer_idx=True,
                    rotary_emb_fraction=rotary_emb_fraction,
                    use_flash_attn=True, fused_mlp=True,
                    fused_bias_fc=True, fused_dropout_add_ln=True,
                    pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)
```

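
Two of the config values above imply derived sizes that are easy to check by hand. The sketch below assumes that `pad_vocab_size_multiple` rounds the vocabulary size up to the next multiple, and that `rotary_emb_fraction` is the fraction of each head's dimension that receives rotary embedding. Both helper names are hypothetical and only illustrate the arithmetic.

```python
import math


def padded_vocab_size(vocab_size, multiple):
    """Round vocab_size up to the nearest multiple (assumed padding rule)."""
    return math.ceil(vocab_size / multiple) * multiple


def rotary_dim(hidden_dim, nheads, rotary_emb_fraction):
    """Number of dimensions per head that get rotary embedding (assumed rule)."""
    head_dim = hidden_dim // nheads
    return int(rotary_emb_fraction * head_dim)


print(padded_vocab_size(50257, 8))  # 50264
print(rotary_dim(2048, 16, 0.5))    # 64
```
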

We provide the following optimized components:

1. FlashAttention: fast and memory-efficient exact attention. This makes
attention much faster and saves a lot of activation memory. As a result we don't need
to use any activation checkpointing.
```sh
pip install flash-attn
```

2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
make it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before
this don't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd ../csrc/fused_dense_lib && pip install .
```

3. Optimized cross-entropy loss, adapted from Apex's
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
```sh
cd ../csrc/xentropy && pip install .
```

4. Fused rotary embedding:
```sh
cd ../csrc/rotary && pip install .
```

5. Fused dropout + residual + LayerNorm, adapted from Apex's
[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
This supports dimensions divisible by 8, up to 6144.
```sh
cd ../csrc/layer_norm && pip install .
```

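
The dimension constraint on the fused dropout + residual + LayerNorm kernel (component 5 above) can be checked before turning on `fused_dropout_add_ln`. This is a minimal sketch of that check; the helper is hypothetical and not part of the library.

```python
def fused_layer_norm_supports(hidden_dim, max_dim=6144, multiple=8):
    """True if hidden_dim satisfies the stated constraint:
    divisible by 8 and no larger than 6144."""
    return 0 < hidden_dim <= max_dim and hidden_dim % multiple == 0


for dim in (2048, 2560, 6144, 8192, 1001):
    print(dim, fused_layer_norm_supports(dim))
```

With these inputs, the first three dimensions pass and the last two do not (8192 is too large, 1001 is not divisible by 8).
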
## Training

We also provide training scripts to train GPT2 on Openwebtext and GPT3 on
The Pile as examples. Feel free to use the model in your own training setup as
well.

We use [Hydra](https://hydra.cc/) for configuration,
[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
[Wandb](https://wandb.ai/) for logging.

We use the template from `https://github.com/ashleve/lightning-hydra-template`.
Please read the instructions there to understand the repo structure.

### Requirements

Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).

We provide a Dockerfile that lists all the required packages.

### Dataset preparation

Running the training command automatically downloads the datasets
(Openwebtext, Pile), tokenizes them with the GPT2 tokenizer, concatenates all the
tokens, and saves this cache to disk. Alternatively, you can prepare the
datasets as a separate step.

The cached datasets are saved to `${DATA_DIR}/openwebtext` and
`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
`./data/{openwebtext,the_pile}`.

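
The cache location rule above amounts to an environment lookup with a fallback. Here is a minimal sketch of that rule; the function name and the `env` parameter are ours, and the repo's own path handling may differ in detail.

```python
import os
from pathlib import Path


def cache_dir(dataset_name, env=None):
    """Resolve the cache directory: ${DATA_DIR}/<name>, or ./data/<name> if unset."""
    env = os.environ if env is None else env
    root = Path(env.get("DATA_DIR", "./data"))
    return root / dataset_name


# Passing a dict instead of os.environ makes the rule easy to try out.
print(cache_dir("openwebtext", env={}))
print(cache_dir("the_pile", env={"DATA_DIR": "/scratch"}))
```
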

- Openwebtext:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
```
This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.

- The Pile:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```
This takes around 20h on a 64-core CPU. The processed dataset has size 699GB.

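
To make the preparation steps concrete (tokenize, concatenate all tokens, save a cache to disk), here is a toy sketch using only the standard library. The whitespace tokenizer is a stand-in for the GPT2 tokenizer, and the cache format (a flat file of unsigned 16-bit ids) and the slicing into fixed-length sequences are assumptions for illustration, not a description of this repo's data module.

```python
import tempfile
from array import array
from pathlib import Path


def toy_tokenize(text, vocab):
    """Stand-in tokenizer (hypothetical): one id per whitespace-separated word."""
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


def build_cache(documents, cache_path):
    """Tokenize every document, concatenate the ids, and write them to disk."""
    vocab = {}
    tokens = array("H")  # unsigned 16-bit ids: enough for a 50257-token vocabulary
    for doc in documents:
        tokens.extend(toy_tokenize(doc, vocab))
    with open(cache_path, "wb") as f:
        tokens.tofile(f)
    return len(tokens)


def load_cache(cache_path):
    """Read the flat token file back into memory."""
    tokens = array("H")
    with open(cache_path, "rb") as f:
        tokens.frombytes(f.read())
    return tokens


def split_into_sequences(tokens, seqlen):
    """Cut the token stream into full-length sequences, dropping the remainder."""
    n = len(tokens) // seqlen
    return [tokens[i * seqlen:(i + 1) * seqlen] for i in range(n)]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_cache.bin"
    n_tokens = build_cache(["a b c", "a d e f", "b g"], path)
    seqs = split_into_sequences(load_cache(path), seqlen=4)

print(n_tokens, [list(s) for s in seqs])
```

Because documents are concatenated before slicing, a training sequence can span a document boundary, as the second sequence does here.
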
### GPT2 training on Openwebtext
To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8  # 125M
python run.py experiment=owt/gpt2m-flash trainer.devices=8  # 355M
python run.py experiment=owt/gpt2l-flash trainer.devices=8  # 760M
python run.py experiment=owt/gpt2xl-flash trainer.devices=8  # 1.6B
```
The default parameters are set for 8 x A100 80GB.

To train with bf16 instead of fp16, add `trainer.precision=bf16`.

### GPT3 training on The Pile
To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8  # 125M
python run.py experiment=pile/gpt3m-flash trainer.devices=8  # 355M
python run.py experiment=pile/gpt3l-flash trainer.devices=8  # 760M
python run.py experiment=pile/gpt3xl-flash trainer.devices=8  # 1.3B
python run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8  # 2.7B
```
The default parameters are set for 8 x A100 80GB. We train with bf16 by default.

To train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.

### Training options

**Gradient accumulation**: to adjust device batch size to fit into GPU memory
(the global batch size stays the same, and gradient accumulation is calculated
automatically), set `datamodule.batch_size=N`, where `N` is the per-device batch size.

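
The automatic calculation mentioned above follows from keeping the global batch size fixed. This sketch shows one plausible form of it, assuming the global batch is the product of device batch size, devices per node, number of nodes, and accumulation steps. The function and the example numbers are hypothetical, and the repo's own logic may differ.

```python
def accumulation_steps(global_batch_size, device_batch_size, devices, num_nodes=1):
    """Gradient accumulation steps needed to reach a fixed global batch size."""
    per_step = device_batch_size * devices * num_nodes
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not a multiple of {per_step}"
        )
    return global_batch_size // per_step


# Example: a global batch of 256 sequences on 8 GPUs.
print(accumulation_steps(256, 16, 8))              # 2
print(accumulation_steps(256, 4, 8))               # 8
print(accumulation_steps(256, 4, 8, num_nodes=2))  # 4
```

Lowering the device batch size raises the number of accumulation steps, so memory use per GPU drops while the optimizer still sees the same global batch.
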

**Multi-node**: to train on multiple nodes, add `trainer.num_nodes=N`, where `N` is the number of nodes.

**Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.

**Resumable training**: give the run a name, then set `resume=True` when
you resume. Training will restart at exactly the same batch.
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True
```

## Training speed

We measure the wallclock training speed on one node with 8 x A100 SXM4 80GB (400W) with NVLink.

FLOPs are calculated using the formula from the [Megatron-LM
paper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4
to get the model FLOPs (instead of hardware FLOPs with activation
checkpointing).

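
For reference, here is a sketch of that calculation as we understand it. It assumes the paper's per-iteration estimate is `96 * B * s * l * h^2 * (1 + s / (6h) + V / (16 * l * h))` for batch size `B`, sequence length `s`, `l` layers, hidden size `h`, and vocabulary size `V`, which counts one extra forward pass for activation recomputation. Scaling by 3/4 removes that extra pass. Function names are ours.

```python
def megatron_hardware_flops(batch_size, seqlen, n_layer, hidden_dim, vocab_size):
    """FLOPs per iteration including activation recomputation (assumed formula)."""
    return (
        96 * batch_size * seqlen * n_layer * hidden_dim**2
        * (1 + seqlen / (6 * hidden_dim) + vocab_size / (16 * n_layer * hidden_dim))
    )


def model_flops(batch_size, seqlen, n_layer, hidden_dim, vocab_size):
    """Model FLOPs per iteration: forward + backward only, hence the 3/4 factor."""
    return 0.75 * megatron_hardware_flops(
        batch_size, seqlen, n_layer, hidden_dim, vocab_size
    )


def tflops_per_gpu(tokens_per_sec, seqlen, n_layer, hidden_dim, vocab_size, n_gpus):
    """Convert a measured token throughput into model TFLOPs/sec per GPU."""
    flops_per_token = model_flops(1, seqlen, n_layer, hidden_dim, vocab_size) / seqlen
    return flops_per_token * tokens_per_sec / n_gpus / 1e12


# GPT3-1.3B configuration from the Model Components example, at 169k tokens/sec on 8 GPUs.
print(tflops_per_gpu(169_000, seqlen=2048, n_layer=24, hidden_dim=2048,
                     vocab_size=50257, n_gpus=8))
```

With the rounded 169k tokens/sec throughput reported for GPT3-1.3B, this lands at roughly 188 TFLOPs/sec per GPU, in line with the headline figure.
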
### GPT2 (sequence length 1024)

![GPT2 speedup](../assets/gpt2_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-4x faster than the
baseline implementation from Huggingface.

### GPT3 (sequence length 2048)

![GPT3 speedup](../assets/gpt3_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-5x faster than the
baseline implementation from Huggingface.

For the GPT3-2.7B model, we set head dimension to 128 (instead of 80) for better efficiency.

We include here more details on the training speed with FlashAttention on 8 x
A100 80GB.

| Model     | Batch size (tokens) | Throughput (tokens/sec) | Hours / 1B tokens |
| --------- | ------------------- | ----------------------- | ----------------- |
| GPT3-125M | 0.5M                | 1310k                   | 0.21              |
| GPT3-355M | 0.5M                | 503k                    | 0.55              |
| GPT3-760M | 0.5M                | 245k                    | 1.13              |
| GPT3-1.3B | 1M                  | 169k                    | 1.64              |
| GPT3-2.7B | 1M                  | 85k                     | 3.27              |

As an example, this means that one can train a GPT3-1.3B model on 26B tokens
(compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100.

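
The last column of the table and the 43-hour estimate follow directly from the throughput column. A small sketch of the conversion, using the throughput values from the table (the function name is ours):

```python
def hours_per_billion_tokens(tokens_per_sec):
    """Wallclock hours needed to process 1B tokens at a given throughput."""
    return 1e9 / tokens_per_sec / 3600


throughput = {
    "GPT3-125M": 1_310_000,
    "GPT3-355M": 503_000,
    "GPT3-760M": 245_000,
    "GPT3-1.3B": 169_000,
    "GPT3-2.7B": 85_000,
}
for name, tps in throughput.items():
    print(f"{name}: {hours_per_billion_tokens(tps):.2f} h / 1B tokens")

# 26B tokens for GPT3-1.3B:
print(round(26 * hours_per_billion_tokens(169_000)))  # 43
```
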
## Training quality

We include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens.
For GPT2, the runs with FlashAttention yield the same loss curve as the runs
with the baseline implementation from Huggingface for 125M and 355M models. For
larger models the baseline implementation just takes too long.

![GPT2 training curve](../assets/gpt2_training_curve.jpg)

We include here the loss curve for GPT3 on The Pile, trained for 400B tokens.
The 125M, 355M, 760M models have batch size 512k tokens so this translates to
roughly 800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens,
which translates to roughly 400k training steps.

![GPT3 training curve](../assets/gpt3_training_curve.jpg)
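
The step counts quoted for the GPT3 runs are the token budget divided by the batch size in tokens. The sketch below shows the arithmetic. Whether "512k" means 500,000 or 512 x 1024 = 524,288 tokens is not stated, so both readings are shown; the second is our assumption.

```python
def training_steps(total_tokens, batch_tokens):
    """Number of optimizer steps needed to consume total_tokens."""
    return total_tokens / batch_tokens


TOTAL_TOKENS = 400_000_000_000  # 400B tokens

print(training_steps(TOTAL_TOKENS, 500_000))           # 800000.0
print(training_steps(TOTAL_TOKENS, 1_000_000))         # 400000.0
# Assumed alternative reading of "512k": 512 * 1024 tokens per batch.
print(round(training_steps(TOTAL_TOKENS, 512 * 1024)))  # 762939
```
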