Add change log

Tri Dao 2023-10-05 14:18:14 -07:00
parent aa4fd2d166
commit 3a9fe7b0fa

README.md

@@ -31,8 +31,7 @@ Please cite and credit FlashAttention if you use it.
Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux. Windows is not supported for now. If you have ideas on how to modify the code to support Windows, please reach out via Github issue.
We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
@@ -83,29 +82,35 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
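As a usage sketch (shapes, dtype, and device below are illustrative assumptions; the kernels require fp16/bf16 tensors on a CUDA GPU):
```python
import torch
from flash_attn import flash_attn_qkvpacked_func

# Illustrative sizes.
batch_size, seqlen, nheads, headdim = 2, 1024, 8, 64

# Q, K, V packed into one tensor: (batch_size, seqlen, 3, nheads, headdim).
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda", requires_grad=True)

# Causal self-attention; keep dropout_p at 0.0 during evaluation.
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
# out: (batch_size, seqlen, nheads, headdim)
```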
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
@@ -115,15 +120,86 @@ Arguments:
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
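A minimal GQA sketch following the docstring's example (6 query heads sharing 2 KV heads); all sizes are illustrative:
```python
import torch
from flash_attn import flash_attn_func

batch_size, seqlen, headdim = 2, 1024, 64
nheads_q, nheads_kv = 6, 2  # nheads_q must be divisible by nheads_kv

q = torch.randn(batch_size, seqlen, nheads_q, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")

# Heads 0-2 of q share KV head 0; heads 3-5 share KV head 1.
out = flash_attn_func(q, k, v, causal=True)
# out: (batch_size, seqlen, nheads_q, headdim)
```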
```python
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
```
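A sketch of one iterative-decoding step with an in-place cache update (the cache length, batch size, and head sizes are illustrative assumptions):
```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, nheads_k, headdim = 2, 8, 8, 64
max_seqlen = 2048   # pre-allocated cache length
cur_len = 130       # tokens already in the cache for every sequence

# Pre-allocated KV cache: (batch_size, seqlen_cache, nheads_k, headdim).
k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda")

# New query/key/value for the current step (seqlen_new = 1 when decoding one token).
q = torch.randn(batch_size, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")

# Current lengths of the cached sequences; k and v are written at these offsets.
cache_seqlens = torch.full((batch_size,), cur_len, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                              cache_seqlens=cache_seqlens, causal=True)
# out: (batch_size, 1, nheads, headdim); k_cache and v_cache are updated in place.
```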
To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).
## Changelog
### 2.0
Upgrading from FlashAttention (1.x) to FlashAttention-2
These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
@@ -138,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
### 2.1
If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.
@@ -167,6 +243,25 @@ v2.1:
1 1
If the row of the mask is all zero, the output will be zero.
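For reference, the bottom-right alignment means query position i attends to key position j when j <= i + seqlen_k - seqlen_q. A small sketch with a hypothetical helper `bottom_right_causal_mask` that builds this mask (1 = keep, 0 = masked out):
```python
import torch

def bottom_right_causal_mask(seqlen_q, seqlen_k):
    # v2.1 alignment: query i attends to key j when j <= i + seqlen_k - seqlen_q.
    i = torch.arange(seqlen_q).unsqueeze(1)
    j = torch.arange(seqlen_k).unsqueeze(0)
    return (j <= i + seqlen_k - seqlen_q).int()

print(bottom_right_causal_mask(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```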
### 2.2
Optimize for inference (iterative decoding) when query has very small sequence
length (e.g., query sequence length = 1). The bottleneck here is to load KV
cache as fast as possible, and we split the loading across different thread
blocks, with a separate kernel to combine results.
See the function `flash_attn_with_kvcache` for more inference features (applying rotary
embeddings, updating the KV cache in place).
Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.
### 2.3
Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
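A usage sketch of sliding window attention via `window_size` (the window width below is an arbitrary illustrative choice):
```python
import torch
from flash_attn import flash_attn_func

batch_size, seqlen, nheads, headdim = 2, 4096, 8, 64
q = torch.randn(batch_size, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")

# Causal local attention: each query attends to at most the 1024 previous keys
# (and itself), i.e. window_size = (1024, 0).
out = flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))
```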
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).