Add speedup to README

Update images

Update description
This commit is contained in:
Dan Fu 2022-05-27 21:59:09 +01:00
parent 9dbc491aa5
commit dc6d130088
3 changed files with 24 additions and 0 deletions

README.md

@@ -28,6 +28,30 @@ Our roadmap to broaden the support:
7. Fused rotary embedding.
8. Attention linear bias (e.g. ALiBi).
### Speedup and Memory Savings
We present the expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, as a function of sequence length.
We display FlashAttention speedup using these parameters (similar to BERT-base):
* Batch size 8
* Head dimension 64
* 12 attention heads
Our graphs show sequence lengths between 128 and 4096 (the point at which standard attention runs out of memory on an A100), but FlashAttention can scale up to a sequence length of 64K.
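As a rough, unofficial sketch (not the script that produced the plots below), the PyTorch standard-attention baseline at these settings can be timed as follows; the dropout value and benchmark loop are our own assumptions, and to measure FlashAttention itself you would call the kernel from this repo's interface on the same `q`, `k`, `v`:
```python
# Minimal sketch of the PyTorch standard-attention baseline at the settings
# above (batch 8, 12 heads, head dimension 64, fp16). Illustrative only.
import math
import torch

def standard_attention(q, k, v, dropout_p=0.0):
    # q, k, v: (batch, heads, seqlen, head_dim)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    attn = torch.softmax(scores, dim=-1)
    attn = torch.nn.functional.dropout(attn, p=dropout_p)
    return torch.matmul(attn, v)

batch, heads, head_dim, seqlen = 8, 12, 64, 1024
q, k, v = (torch.randn(batch, heads, seqlen, head_dim, device="cuda",
                       dtype=torch.float16, requires_grad=True)
           for _ in range(3))

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
out = standard_attention(q, k, v, dropout_p=0.1)
out.sum().backward()  # combined forward + backward pass
end.record()
torch.cuda.synchronize()
print(f"seqlen {seqlen}: {start.elapsed_time(end):.2f} ms")
```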
#### Speedup
![FlashAttention speedup](images/flashattn_speedup.png)
We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
#### Memory
![FlashAttention memory](images/flashattn_memory.png)
We show memory savings in this graph (note that the memory footprint is the same whether or not you use dropout or masking).
Memory savings are proportional to sequence length, since standard attention uses memory quadratic in sequence length, whereas FlashAttention's memory is linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.
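As a back-of-envelope illustration of why the savings grow with sequence length (our own estimate, ignoring masks, workspace buffers, and extra copies, so the exact ratios differ from the measured ones above): standard attention materializes a `batch x heads x seqlen x seqlen` attention matrix, while FlashAttention only keeps the `batch x heads x seqlen x head_dim` inputs and outputs plus O(seqlen) softmax statistics.
```python
# Rough fp16 memory estimate at the settings above (batch 8, 12 heads, head dim 64).
batch, heads, head_dim, bytes_per_el = 8, 12, 64, 2  # 2 bytes per fp16 element

for seqlen in (512, 1024, 2048, 4096):
    # Standard attention materializes at least one (batch, heads, seqlen, seqlen) matrix.
    quadratic = batch * heads * seqlen * seqlen * bytes_per_el
    # FlashAttention keeps Q, K, V, O of shape (batch, heads, seqlen, head_dim).
    linear = 4 * batch * heads * seqlen * head_dim * bytes_per_el
    print(f"seqlen {seqlen:5d}: attention matrix {quadratic / 2**20:8.1f} MiB, "
          f"Q/K/V/O {linear / 2**20:6.1f} MiB, ratio {quadratic / linear:.0f}x")
```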
### Acknowledgments
Our implementation uses Apex's

BIN  images/flashattn_memory.png (new file, 64 KiB; binary file not shown)

BIN  images/flashattn_speedup.png (new file, 80 KiB; binary file not shown)