diff --git a/README.md b/README.md
index 536dbe7..88ba4ed 100644
--- a/README.md
+++ b/README.md
@@ -50,7 +50,7 @@ cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
 pip install flash-attn --no-build-isolation
 ```
 Alternatively you can compile from source:
-```
+```sh
 python setup.py install
 ```
 
@@ -58,7 +58,7 @@ If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
 run too many parallel compilation jobs that could exhaust the amount of RAM.
 To limit the number of parallel compilation jobs, you can set the environment
 variable `MAX_JOBS`:
-```
+```sh
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
 
@@ -76,11 +76,11 @@ FlashAttention-2 currently supports:
 
 The main functions implement scaled dot product attention (softmax(Q @ K^T *
 softmax_scale) @ V):
-```
+```python
 from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 ```
 
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -94,9 +94,10 @@ Arguments:
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
     out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
 
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -114,6 +115,7 @@ Arguments:
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
     out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
 
 To see how these functions are used in a multi-head attention layer (which
@@ -128,10 +130,10 @@ These functions have been renamed:
 
 If the inputs have the same sequence lengths in the same batch, it is simpler
 and faster to use these functions:
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
 ```
 
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
@@ -205,7 +207,7 @@ of a baseline implementation in Pytorch (for different head dimensions, input
 dtype, sequence length, causal / non-causal).
 
 To run the tests:
-```
+```sh
 pytest -q -s tests/test_flash_attn.py
 ```
 ## When you encounter issues
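
For context, a minimal usage sketch of the two functions whose docstrings this patch closes off. It is not part of the diff: the concrete tensor shapes, the fp16 dtype, and the CUDA device are illustrative assumptions consistent with the README (FlashAttention-2 expects fp16/bf16 inputs on a supported GPU).

```python
# Illustrative only -- not part of the patch. Shapes follow the documented
# signatures: q/k/v are (batch_size, seqlen, nheads, headdim) and the packed
# qkv tensor is (batch_size, seqlen, 3, nheads, headdim).
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 1024, 8, 64

# fp16 on CUDA is an assumption here; FlashAttention-2 supports fp16/bf16 only.
q = torch.randn(batch_size, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Separate Q, K, V with a causal mask; output is (batch_size, seqlen, nheads, headdim).
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# Equivalent call with Q, K, V packed into a single tensor along dim 2.
qkv = torch.stack([q, k, v], dim=2)
out_packed = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
```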