3090 speedup
parent 5a61cb7729
commit ad6c694bb3

README.md (18)
@@ -39,7 +39,10 @@ Our tentative roadmap:
## Speedup and Memory Savings
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
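For reference, the PyTorch standard attention baseline being compared against looks roughly like the sketch below (a minimal illustration, not the exact benchmark code; the tensor layout `(batch, heads, seqlen, head_dim)` and the helper name `standard_attention` are assumptions). Masking and dropout are separate elementwise ops here, which is what FlashAttention fuses into a single kernel.

```python
import math
import torch
import torch.nn.functional as F

def standard_attention(q, k, v, attn_mask=None, dropout_p=0.0):
    """Reference PyTorch attention; q, k, v are (batch, heads, seqlen, head_dim)."""
    # Materializes a (seqlen x seqlen) score matrix per head -- this is the
    # quadratic-memory term that FlashAttention avoids.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, float("-inf"))
    attn = F.dropout(F.softmax(scores, dim=-1), p=dropout_p)
    return attn @ v
```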
### A100
We display FlashAttention speedup using these parameters (similar to BERT-base):
* Batch size 8
* Head dimension 64
@@ -47,14 +50,14 @@ We display FlashAttention speedup using these parameters (similar to BERT-base):
Our graphs show sequence lengths between 128 and 4096 (when standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
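A timing loop in that spirit might look like the sketch below, reusing the `standard_attention` helper sketched earlier (illustrative only, not the script that produced the plots; the head count of 12 is an assumption based on the BERT-base-like setting, and the FlashAttention call site is left as a comment since the exact entry point depends on the installed version):

```python
import torch

def time_fwd_bwd(fn, *inputs, repeats=30):
    # Average wall-clock time (ms) of a combined forward + backward pass.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(repeats):
        fn(*inputs).sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeats

batch, heads, head_dim = 8, 12, 64  # heads=12 assumed (BERT-base-like)
for seqlen in [128, 256, 512, 1024, 2048, 4096]:
    q, k, v = (torch.randn(batch, heads, seqlen, head_dim, device="cuda",
                           dtype=torch.float16, requires_grad=True) for _ in range(3))
    ms = time_fwd_bwd(standard_attention, q, k, v)  # baseline sketched above
    print(f"seqlen={seqlen}: standard attention fwd+bwd {ms:.2f} ms/iter")
    # Swap in the FlashAttention kernel here and divide to get the speedup ratio.
```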
#### Speedup

We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
#### Memory

@@ -63,6 +66,15 @@ Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.
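A back-of-the-envelope calculation illustrates where the quadratic memory goes (fp16, the A100 settings above; the 12-head count is an assumption, and this counts only one materialized attention matrix while training typically keeps several such buffers alive):

```python
# Size of one materialized (batch, heads, seqlen, seqlen) attention matrix in fp16.
batch, heads, bytes_per_elem = 8, 12, 2  # heads=12 assumed (BERT-base-like)
for seqlen in (1024, 2048, 4096):
    gib = batch * heads * seqlen * seqlen * bytes_per_elem / 2**30
    print(f"seqlen={seqlen}: {gib:.2f} GiB")
# Doubling seqlen quadruples this term; FlashAttention never materializes it,
# so its savings over standard attention grow roughly linearly with seqlen.
```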
### RTX 3090
For the RTX 3090, we use batch size 12 with 12 attention heads.
Memory savings are the same as on an A100, so we'll only show speedup here.

We see slightly higher speedups (between 2.5-4.5x) on the RTX 3090, since GDDR6X memory bandwidth is lower than that of A100 HBM (~900 GB/s vs. ~1.5 TB/s).
## Acknowledgments
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
BIN assets/flashattn_speedup_3090.jpg (new binary file, 106 KiB, not shown)