Add a simple tutorial to README.md

Our tentative roadmap:
9. ~~[Aug 2022] Fuse rotary embedding~~ [Done].
10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).

## How to use FlashAttention

Here's a simple example:
```python
import torch
from flash_attn.flash_attention import FlashMHA

# Replace this with your correct GPU device
device = "cuda:0"

# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
    embed_dim=128,  # total channels (= num_heads * head_dim)
    num_heads=8,  # number of heads
    device=device,
    dtype=torch.float16,
)

# Run forward pass with dummy data
x = torch.randn(
    (64, 256, 128),  # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16
)

output = flash_mha(x)[0]
```
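
`flash_mha` returns a tuple (the `[0]` above picks out the attention output), and that output keeps the input's shape:

```python
# The attention output has the same (batch, seqlen, embed_dim) shape as the input
assert output.shape == x.shape  # torch.Size([64, 256, 128])
```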

Alternatively, you can import the inner attention layer only (so that the input
and output linear layers are not included):
```python
from flash_attn.flash_attention import FlashAttention

# Create the nn.Module
flash_attention = FlashAttention()
```
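
Since `FlashAttention` has no projection layers, you hand it the query, key, and value tensors yourself. Here is a minimal sketch of a forward pass, assuming a packed qkv tensor of shape `(batch, seqlen, 3, nheads, headdim)` (the layout the `qkvpacked` function names below suggest); check the module's docstring for the exact forward signature:

```python
import torch
from flash_attn.flash_attention import FlashAttention

flash_attention = FlashAttention()

# Assumption: packed qkv of shape (batch, seqlen, 3, nheads, headdim) in fp16
# on the GPU; 8 heads of dim 16 match the embed_dim=128 example above.
qkv = torch.randn(64, 256, 3, 8, 16, device="cuda:0", dtype=torch.float16)

# Assumption: like FlashMHA, the module returns a tuple of
# (attention output, attention weights)
out, _ = flash_attention(qkv)
# out: (batch, seqlen, nheads, headdim)
```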

Or, if you need more fine-grained control, you can import one of the lower-level
functions (this is more similar to the `torch.nn.functional` style):
```python
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

# or

from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func

# etc.
```
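
The `unpadded` variants avoid padding altogether: every sequence in the batch is concatenated into a single `(total_tokens, nheads, headdim)` tensor, and an int32 tensor of cumulative sequence lengths marks the boundaries. The sketch below shows that calling convention; the argument names and order are assumptions inferred from the function name, so verify them against the docstrings in `flash_attn_interface`:

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

nheads, headdim = 8, 16
seqlens = [200, 56]   # two sequences of different lengths, no padding
total = sum(seqlens)  # 256 tokens in the concatenated batch

q = torch.randn(total, nheads, headdim, device="cuda:0", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths: sequence i spans [cu_seqlens[i], cu_seqlens[i+1])
cu_seqlens = torch.tensor([0, 200, 256], dtype=torch.int32, device="cuda:0")

# Assumed argument order (check the docstring): q, k, v, then the cu_seqlens
# and max sequence length for queries and keys, then dropout probability.
out = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens, cu_seqlens,
    max(seqlens), max(seqlens),
    dropout_p=0.0,
    causal=False,
)
# out: (total_tokens, nheads, headdim)
```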

There are also separate Python files with various FlashAttention extensions:
```python
# Import the triton implementation (torch.nn.functional version only)
from flash_attn.flash_attn_triton import flash_attn_func

# Import block sparse attention (nn.Module version)
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention

# Import block sparse attention (torch.nn.functional version)
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
```
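
For instance, the Triton version can be called directly on query, key, and value tensors in the functional style. The `(batch, seqlen, nheads, headdim)` layout and the `causal` keyword below are assumptions rather than a documented contract; consult `flash_attn_triton` before relying on them:

```python
import torch
from flash_attn.flash_attn_triton import flash_attn_func

# Assumed layout: separate q, k, v tensors of shape
# (batch, seqlen, nheads, headdim), fp16 on the GPU.
q = torch.randn(64, 256, 8, 16, device="cuda:0", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # assumed keyword; see the source
```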

## Speedup and Memory Savings

We present the expected speedup (combined forward and backward pass) and memory savings from using FlashAttention compared to PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth; we see more speedup on slower GPU memory).