From c9a649805bccc4adb6df4f22f1c12562501070f0 Mon Sep 17 00:00:00 2001
From: Caleb Thomas <36911613+calebthomas259@users.noreply.github.com>
Date: Tue, 27 Dec 2022 14:13:59 +0800
Subject: [PATCH] Add a simple tutorial to README.md

---
 README.md | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/README.md b/README.md
index 4e4cf28..b31de49 100644
--- a/README.md
+++ b/README.md
@@ -74,6 +74,69 @@ Our tentative roadmap:
 9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
 10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
+
+## How to use FlashAttention
+
+Here's a simple example:
+```python
+import torch
+from flash_attn.flash_attention import FlashMHA
+
+# Replace this with your correct GPU device
+device = "cuda:0"
+
+# Create attention layer. This is similar to torch.nn.MultiheadAttention,
+# and it includes the input and output linear layers
+flash_mha = FlashMHA(
+    embed_dim=128,  # total channels (= num_heads * head_dim)
+    num_heads=8,  # number of heads
+    device=device,
+    dtype=torch.float16,
+)
+
+# Run forward pass with dummy data
+x = torch.randn(
+    (64, 256, 128),  # (batch, seqlen, embed_dim)
+    device=device,
+    dtype=torch.float16
+)
+
+output = flash_mha(x)[0]
+```
+
+Alternatively, you can import the inner attention layer only (so that the input
+and output linear layers are not included):
+```python
+from flash_attn.flash_attention import FlashAttention
+
+# Create the nn.Module
+flash_attention = FlashAttention()
+```
+
+Or, if you need more fine-grained control, you can import one of the lower-level
+functions (this is more similar to the `torch.nn.functional` style):
+```python
+from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+
+# or
+
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
+
+# etc.
+```
+
+There are also separate Python files with various FlashAttention extensions:
+```python
+# Import the triton implementation (torch.nn.functional version only)
+from flash_attn.flash_attn_triton import flash_attn_func
+
+# Import block sparse attention (nn.Module version)
+from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
+
+# Import block sparse attention (torch.nn.functional version)
+from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+```
+
 
 ## Speedup and Memory Savings
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
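A note on the inner `FlashAttention` module in the tutorial above: the snippet there only constructs the module. Below is a minimal sketch of a forward pass, assuming the module takes a packed QKV tensor of shape `(batch, seqlen, 3, nheads, headdim)` in fp16 on the GPU and returns a tuple of `(output, attention_weights)`; the exact forward signature may differ between releases, so check `flash_attn/flash_attention.py` in your installed version.

```python
import torch
from flash_attn.flash_attention import FlashAttention

device = "cuda:0"
dtype = torch.float16
batch, seqlen, nheads, headdim = 4, 256, 8, 64  # arbitrary example sizes

# The inner module has no input/output projections, so you supply the
# already-projected queries, keys and values yourself, packed together.
flash_attention = FlashAttention()

# Assumed packed layout: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device=device, dtype=dtype)

# Assumed to return (output, attention_weights)
context, _ = flash_attention(qkv)
# context: (batch, seqlen, nheads, headdim)
```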
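Similarly, the lower-level `flash_attn_unpadded_func` import above is shown without a call site. The sketch below illustrates one plausible way to call it, assuming the unpadded ("varlen") convention: queries, keys and values for the whole batch packed into `(total_tokens, nheads, headdim)` tensors, with int32 cumulative sequence lengths marking the batch boundaries. The argument order, names, and tensor sizes here are assumptions for illustration; the authoritative signature is in `flash_attn/flash_attn_interface.py`.

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

device = "cuda:0"
dtype = torch.float16
batch, seqlen, nheads, headdim = 4, 256, 8, 64  # arbitrary example sizes

# The unpadded interface packs every sequence in the batch into one
# (total_tokens, nheads, headdim) tensor; batch boundaries are described by
# cumulative sequence lengths (int32). With equal-length sequences this is
# simply [0, seqlen, 2*seqlen, ..., batch*seqlen].
cu_seqlens = torch.arange(
    0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device
)
total = batch * seqlen

q = torch.randn(total, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(total, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(total, nheads, headdim, device=device, dtype=dtype)

out = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens, cu_seqlens,  # cu_seqlens_q, cu_seqlens_k
    seqlen, seqlen,          # max_seqlen_q, max_seqlen_k
    dropout_p=0.0,           # disable attention dropout (e.g. for inference)
    causal=False,            # set True for autoregressive (causal) masking
)
# out: (total_tokens, nheads, headdim)
```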