README syntax highlighting (#365)
* README syntax highlighting

Adds syntax highlighting to README

* Update README.md
parent 425dbcb6c6
commit cbf982afa5

README.md (18 changed lines)
@@ -50,7 +50,7 @@ cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
 pip install flash-attn --no-build-isolation
 ```
 Alternatively you can compile from source:
-```
+```sh
 python setup.py install
 ```
 
@@ -58,7 +58,7 @@ If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
 run too many parallel compilation jobs that could exhaust the amount of RAM. To
 limit the number of parallel compilation jobs, you can set the environment
 variable `MAX_JOBS`:
-```
+```sh
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
 
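A usage note, not part of the diff: `MAX_JOBS` can also be set when building from source, assuming the source build goes through the same `ninja`-based extension builder that the pip path uses. A minimal sketch for a RAM-limited machine:

```sh
# Sketch only: cap parallel compilation jobs for a source build as well.
export MAX_JOBS=4
python setup.py install
```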
@@ -76,11 +76,11 @@ FlashAttention-2 currently supports:
 
 The main functions implement scaled dot product attention (softmax(Q @ K^T *
 softmax_scale) @ V):
-```
+```python
 from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 ```
 
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
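For readers skimming the diff, a plain-PyTorch sketch of the quantity described above, softmax(Q @ K^T * softmax_scale) @ V, may help. This is an illustrative reference only, not the library's kernel:

```python
# Reference sketch of what the flash-attn functions compute, in plain PyTorch.
import math
import torch

def attention_reference(q, k, v, softmax_scale=None, causal=False):
    # q, k, v: (batch_size, seqlen, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    # scores: (batch_size, nheads, seqlen_q, seqlen_k)
    scores = torch.einsum("bthd,bshd->bhts", q, k) * softmax_scale
    if causal:
        # Mask out keys that lie in the future of each query position.
        mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool,
                                     device=scores.device), 1)
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    # Back to (batch_size, seqlen, nheads, headdim)
    return torch.einsum("bhts,bshd->bthd", attn, v)
```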
@@ -94,9 +94,10 @@ Arguments:
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
     out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
 
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
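A hypothetical call sketch for `flash_attn_qkvpacked_func`, using the packed `(batch_size, seqlen, 3, nheads, headdim)` layout the full README documents for `qkv`; the half-precision dtype and CUDA device are assumptions about the kernel's requirements, not something stated in the quoted hunk:

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
# Q, K, V packed into one tensor; fp16 on GPU assumed.
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
assert out.shape == (batch_size, seqlen, nheads, headdim)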
@@ -114,6 +115,7 @@ Arguments:
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
     out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
 
 To see how these functions are used in a multi-head attention layer (which
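A hypothetical MQA/GQA sketch for `flash_attn_func`, matching the "KV with fewer heads" behavior the docstring describes; head counts, dtype, and device are illustrative assumptions:

```python
import torch
from flash_attn import flash_attn_func

batch_size, seqlen, headdim = 2, 1024, 64
nheads_q, nheads_kv = 16, 4   # the library expects nheads_q to be a multiple of nheads_kv
q = torch.randn(batch_size, seqlen, nheads_q, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, causal=True)   # out: (batch_size, seqlen, nheads_q, headdim)
```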
@@ -128,10 +130,10 @@ These functions have been renamed:
 
 If the inputs have the same sequence lengths in the same batch, it is simpler
 and faster to use these functions:
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
 ```
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
 
@@ -205,7 +207,7 @@ of a baseline implementation in Pytorch (for different head dimensions, input
 dtype, sequence length, causal / non-causal).
 
 To run the tests:
-```
+```sh
 pytest -q -s tests/test_flash_attn.py
 ```
 ## When you encounter issues
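One small usage note, not part of the diff: pytest's standard `-k` selector can narrow the run to a subset of the suite; the test name below is assumed for illustration.

```sh
# Illustrative only: run a single test (name assumed) instead of the full suite.
pytest -q -s tests/test_flash_attn.py -k "test_flash_attn_output"
```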