Add a simple tutorial to README.md

Caleb Thomas 2022-12-27 14:13:59 +08:00
parent 1bc6e5b09c
commit c9a649805b


@@ -74,6 +74,69 @@ Our tentative roadmap:
9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
## How to use FlashAttention
Here's a simple example:
```python
import torch
from flash_attn.flash_attention import FlashMHA

# Replace this with your correct GPU device
device = "cuda:0"

# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
    embed_dim=128,  # total channels (= num_heads * head_dim)
    num_heads=8,    # number of heads
    device=device,
    dtype=torch.float16,
)

# Run forward pass with dummy data
x = torch.randn(
    (64, 256, 128),  # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16,
)

output = flash_mha(x)[0]
```
Alternatively, you can import the inner attention layer only (so that the input
and output linear layers are not included):
```python
from flash_attn.flash_attention import FlashAttention
# Create the nn.Module
flash_attention = FlashAttention()
```
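Once created, the module can be called on a packed qkv tensor. Below is a minimal sketch, assuming the module accepts `qkv` of shape `(batch, seqlen, 3, num_heads, head_dim)` in fp16 on the GPU and returns an `(output, attention_weights)` tuple; treat the shapes and return values as assumptions and check the docstring of `FlashAttention.forward` in your installed version.
```python
import torch
from flash_attn.flash_attention import FlashAttention

device = "cuda:0"
flash_attention = FlashAttention()

# Assumption: packed qkv tensor of shape (batch, seqlen, 3, num_heads, head_dim), fp16 on GPU
qkv = torch.randn((64, 256, 3, 8, 16), device=device, dtype=torch.float16)

# Assumption: the module returns (output, attention_weights), with output
# of shape (batch, seqlen, num_heads, head_dim)
output = flash_attention(qkv, causal=False)[0]
```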
Or, if you need more fine-grained control, you can import one of the lower-level
functions (this is more similar to the `torch.nn.functional` style):
```python
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
# or
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
# etc.
```
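As a rough illustration of the lower-level interface, here is a minimal sketch of calling `flash_attn_unpadded_func`. It assumes the "unpadded" layout (all sequences concatenated along the token dimension, with cumulative sequence lengths passed separately) and an argument order of `(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, ...)`; both are assumptions, so consult the docstrings in `flash_attn_interface.py` for the exact signature.
```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

device = "cuda:0"
batch, seqlen, num_heads, head_dim = 4, 256, 8, 64

# Assumption: q, k, v are "unpadded", i.e. all sequences are concatenated
# along the token dimension, giving shape (total_tokens, num_heads, head_dim)
total_tokens = batch * seqlen
q = torch.randn((total_tokens, num_heads, head_dim), device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Assumption: cumulative sequence lengths (int32, shape (batch + 1,)) mark
# where each sequence starts in the packed token dimension
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=device)

# Assumed argument order: q, k, v, cu_seqlens_q, cu_seqlens_k,
# max_seqlen_q, max_seqlen_k, dropout_p, then keyword options
output = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens,
                                  seqlen, seqlen, 0.0, causal=True)
```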
There are also separate Python files with various FlashAttention extensions:
```python
# Import the triton implementation (torch.nn.functional version only)
from flash_attn.flash_attn_triton import flash_attn_func
# Import block sparse attention (nn.Module version)
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
# Import block sparse attention (torch.nn.functional version)
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
```
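For instance, the Triton version is a `torch.nn.functional`-style function. The sketch below assumes it takes separate `q`, `k`, `v` tensors of shape `(batch, seqlen, num_heads, head_dim)` plus optional `bias` and `causal` arguments; treat the shapes and keyword names as assumptions and check `flash_attn_triton.py` for the exact signature.
```python
import torch
from flash_attn.flash_attn_triton import flash_attn_func

device = "cuda:0"

# Assumption: q, k, v are laid out as (batch, seqlen, num_heads, head_dim), fp16 on GPU
q = torch.randn((4, 256, 8, 64), device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Assumption: an optional attention bias and a causal flag are supported as keywords
output = flash_attn_func(q, k, v, bias=None, causal=True)
```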
## Speedup and Memory Savings
We present the expected speedup (combined forward and backward pass) and memory savings of FlashAttention over PyTorch standard attention, as a function of sequence length, on different GPUs (speedup depends on memory bandwidth; we see more speedup on GPUs with slower memory).