README syntax highlighting (#365)

* README syntax highlighting

Adds syntax highlighting to README

* Update README.md
Authored by Ian Timmis on 2023-07-23 03:21:30 -04:00, committed by GitHub
parent 425dbcb6c6
commit cbf982afa5


@@ -50,7 +50,7 @@ cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
 pip install flash-attn --no-build-isolation
 ```
 Alternatively you can compile from source:
-```
+```sh
 python setup.py install
 ```
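
After either install route, a quick sanity check confirms the package imports cleanly. This is a minimal sketch, not part of the diff above; it assumes a CUDA-capable environment where the wheel or source build succeeded and that `flash_attn` exposes `__version__` (as recent releases do).

```python
# Sanity check after installation (a sketch; assumes the build succeeded).
import torch
import flash_attn

print(flash_attn.__version__)
print(torch.cuda.is_available())  # flash-attn kernels require a CUDA GPU
```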
@@ -58,7 +58,7 @@ If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
 run too many parallel compilation jobs that could exhaust the amount of RAM. To
 limit the number of parallel compilation jobs, you can set the environment
 variable `MAX_JOBS`:
-```
+```sh
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
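
For reference, the same cap on parallel compile jobs can be applied when driving pip from a Python provisioning script. This is only a sketch of the shell command above, assuming pip and ninja are available; nothing here is prescribed by the README itself.

```python
# Sketch: install flash-attn with build parallelism capped at 4 jobs,
# mirroring `MAX_JOBS=4 pip install flash-attn --no-build-isolation`.
import os
import subprocess
import sys

env = dict(os.environ, MAX_JOBS="4")  # MAX_JOBS limits parallel compile jobs
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env=env,
    check=True,
)
```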
@@ -76,11 +76,11 @@ FlashAttention-2 currently supports:
 The main functions implement scaled dot product attention (softmax(Q @ K^T *
 softmax_scale) @ V):
-```
+```python
 from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 ```
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -94,9 +94,10 @@ Arguments:
 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
 out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -114,6 +115,7 @@ Arguments:
 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
 out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
 To see how these functions are used in a multi-head attention layer (which
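
The two signatures documented in the hunks above are enough to sketch a small usage example. The following is illustrative only and not taken from the README diff; it assumes a CUDA GPU, fp16 inputs, and the tensor layouts described above, with arbitrary sizes chosen for the demo.

```python
# Illustrative sketch of the two entry points documented above.
# Assumes a CUDA device and fp16 tensors; sizes are arbitrary.
import torch
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64

# Packed layout: (batch_size, seqlen, 3, nheads, headdim)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
print(out.shape)  # (batch, seqlen, nheads, headdim)

# Unpacked layout with grouped-query attention: KV has fewer heads than Q.
nheads_kv = 2  # 8 query heads grouped over 2 key/value heads
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # (batch, seqlen, nheads, headdim)
```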
@@ -128,10 +130,10 @@ These functions have been renamed:
 If the inputs have the same sequence lengths in the same batch, it is simpler
 and faster to use these functions:
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
 ```
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
@@ -205,7 +207,7 @@ of a baseline implementation in Pytorch (for different head dimensions, input
 dtype, sequence length, causal / non-causal).
 To run the tests:
-```
+```sh
 pytest -q -s tests/test_flash_attn.py
 ```
 ## When you encounter issues
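
As an aside, the test command in the hunk above can also be invoked from Python, which is occasionally convenient when debugging a failing case interactively. This is a sketch; it assumes pytest is installed and the command is run from the repository root.

```python
# Run the FlashAttention test suite programmatically (sketch; assumes
# pytest is installed and the current directory is the repo root).
import pytest

pytest.main(["-q", "-s", "tests/test_flash_attn.py"])
```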