README syntax highlighting (#365)

* README syntax highlighting

Adds syntax highlighting to README

* Update README.md
This commit is contained in:
Ian Timmis 2023-07-23 03:21:30 -04:00 committed by GitHub
parent 425dbcb6c6
commit cbf982afa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -50,7 +50,7 @@ cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```
```sh
python setup.py install
```
@ -58,7 +58,7 @@ If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
run too many parallel compilation jobs that could exhaust the amount of RAM. To
limit the number of parallel compilation jobs, you can set the environment
variable `MAX_JOBS`:
```
```sh
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```
@ -76,11 +76,11 @@ FlashAttention-2 currently supports:
The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```
```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```
```
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
@ -94,9 +94,10 @@ Arguments:
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@ -114,6 +115,7 @@ Arguments:
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
To see how these functions are used in a multi-head attention layer (which
@ -128,10 +130,10 @@ These functions have been renamed:
If the inputs have the same sequence lengths in the same batch, it is simpler
and faster to use these functions:
```
```python
flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
```
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
@ -205,7 +207,7 @@ of a baseline implementation in Pytorch (for different head dimensions, input
dtype, sequence length, causal / non-causal).
To run the tests:
```
```sh
pytest -q -s tests/test_flash_attn.py
```
## When you encounter issues