diff --git a/README.md b/README.md
index 6baa06c..d2c553e 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,30 @@ Our roadmap to broaden the support:
 7. Fused rotary embedding.
 8. Attention linear bias (e.g. ALiBi).
 
+### Speedup and Memory Savings
+
+We present the expected speedup (combined forward + backward pass) and memory savings of FlashAttention over PyTorch standard attention, as a function of sequence length.
+We benchmark FlashAttention speedup with these parameters (similar to BERT-base):
+* Batch size 8
+* Head dimension 64
+* 12 attention heads
+Our graphs show sequence lengths between 128 and 4096 (the point at which standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
+
+#### Speedup
+
+![FlashAttention speedup](images/flashattn_speedup.png)
+
+We generally see a 2-4X speedup at sequence lengths between 128 and 4K, and the speedup is larger when dropout and masking are used, since we fuse those kernels.
+At sequence lengths popular with language models, such as 512 and 1K, we see speedups of up to 4X when using dropout and masking.
+
+#### Memory
+
+![FlashAttention memory](images/flashattn_memory.png)
+
+This graph shows the memory savings (note that the memory footprint is the same regardless of whether you use dropout or masking).
+Memory savings are proportional to sequence length, since standard attention uses memory quadratic in sequence length, whereas FlashAttention uses memory linear in sequence length.
+We see a 10X memory saving at sequence length 2K, and 20X at 4K.
+As a result, FlashAttention can scale to much longer sequence lengths.
 
 ### Acknowledgments
 Our implementation uses Apex's
diff --git a/images/flashattn_memory.png b/images/flashattn_memory.png
new file mode 100644
index 0000000..42a82ec
Binary files /dev/null and b/images/flashattn_memory.png differ
diff --git a/images/flashattn_speedup.png b/images/flashattn_speedup.png
new file mode 100644
index 0000000..3b1a895
Binary files /dev/null and b/images/flashattn_speedup.png differ
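
As context for the memory-scaling claim in the README hunk above, here is a rough back-of-the-envelope sketch. It assumes fp16 activations and counts only Q, K, V, the output, and (for standard attention) the materialized attention matrix; the helper `activation_bytes` is hypothetical and the accounting is a simplification, not FlashAttention's exact bookkeeping. The batch size, head count, and head dimension mirror the benchmark settings stated in the section.

```python
# Rough back-of-the-envelope model of attention activation memory in fp16.
# Illustrative only: the constants mirror the benchmark settings above
# (batch 8, 12 heads, head dimension 64) and the accounting is simplified,
# not FlashAttention's exact bookkeeping.

def activation_bytes(seqlen, batch=8, heads=12, headdim=64, bytes_per_el=2):
    # Both implementations keep Q, K, V and the output O:
    # four tensors of shape (batch, seqlen, heads, headdim).
    qkvo = 4 * batch * seqlen * heads * headdim * bytes_per_el
    # Standard attention also materializes the (seqlen x seqlen) attention
    # matrix for every batch element and head, so it grows quadratically.
    standard = qkvo + batch * heads * seqlen * seqlen * bytes_per_el
    # FlashAttention never materializes that matrix; the O(seqlen) softmax
    # statistics it keeps per head are negligible here, so it grows linearly.
    flash = qkvo
    return standard, flash

for seqlen in (1024, 2048, 4096):
    std, fla = activation_bytes(seqlen)
    print(f"seqlen={seqlen:5d}: standard ~{std / 2**20:6.0f} MiB, "
          f"FlashAttention ~{fla / 2**20:4.0f} MiB, ratio ~{std / fla:.1f}x")
```

In this simplified model the ratio roughly doubles when the sequence length doubles, which matches the scaling described in the section (10X savings at 2K, 20X at 4K); the measured numbers in the graph naturally include overheads this sketch ignores.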