Add change log

parent aa4fd2d166
commit 3a9fe7b0fa

Changed file: README.md
@@ -31,8 +31,7 @@ Please cite and credit FlashAttention if you use it.

Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux. Windows is not supported for now. If you have ideas on how to modify the code to support Windows, please reach out via Github issue.

We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
@@ -83,29 +82,35 @@

```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

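For illustration, here is a minimal usage sketch of `flash_attn_qkvpacked_func` (the tensor sizes, fp16 dtype, and CUDA device below are made-up example values, not library defaults):

```python
# Minimal sketch: qkv packs Q, K, V along dimension 2, as described above.
import torch
from flash_attn import flash_attn_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda")

# Causal self-attention; dropout_p should be 0.0 during evaluation.
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
assert out.shape == (batch_size, seqlen, nheads, headdim)
```
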
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
    q: (batch_size, seqlen, nheads, headdim)
@@ -115,15 +120,86 @@ Arguments:
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

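Similarly, here is a minimal MQA/GQA usage sketch of `flash_attn_func` (head counts mirror the 6-query-head / 2-KV-head example above; the sizes, dtype, and sliding-window choice are illustrative assumptions):

```python
# Minimal sketch: grouped-query attention with 6 query heads and 2 KV heads.
import torch
from flash_attn import flash_attn_func

batch_size, seqlen, headdim = 2, 2048, 64
nheads_q, nheads_kv = 6, 2  # nheads_q must be divisible by nheads_kv

q = torch.randn(batch_size, seqlen, nheads_q, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")

# Causal attention restricted to a sliding window of the 256 previous keys;
# window_size=(-1, -1) would mean an unrestricted window.
out = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))
assert out.shape == (batch_size, seqlen, nheads_q, headdim)
```
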
```python
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
```

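To make the incremental-decoding flow concrete, here is a minimal single-step sketch of `flash_attn_with_kvcache` (the cache capacity, head counts, dtype, and one-token query below are illustrative assumptions, not library defaults):

```python
# Minimal sketch: append this step's key/value to the cache in place
# and attend over the updated cache, all in one call.
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, nheads_k, headdim = 2, 16, 16, 64
max_seqlen = 4096  # capacity of the KV cache
cur_len = 128      # tokens already stored for every sequence

k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda")
cache_seqlens = torch.full((batch_size,), cur_len, dtype=torch.int32, device="cuda")

# New query/key/value for the current decoding step (query length 1).
q = torch.randn(batch_size, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                              cache_seqlens=cache_seqlens, causal=True)
assert out.shape == (batch_size, 1, nheads, headdim)
```
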
To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

## Changelog

### 2.0
Upgrading from FlashAttention (1.x) to FlashAttention-2

These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
@@ -138,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)

```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```

### 2.1

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.

@@ -167,6 +243,25 @@ v2.1:

    1 1
If the row of the mask is all zero, the output will be zero.

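As an illustration (reference code, not the library's implementation), the bottom-right-aligned causal mask described in this section can be constructed as follows:

```python
import torch

# Reference construction of the v2.1 causal mask (1 = keep, 0 = masked out):
# query i may attend to key j exactly when j <= i + seqlen_k - seqlen_q.
def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    row = torch.arange(seqlen_q).unsqueeze(1)  # query positions
    col = torch.arange(seqlen_k).unsqueeze(0)  # key positions
    return (col <= row + (seqlen_k - seqlen_q)).int()

print(bottom_right_causal_mask(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
```
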
### 2.2

Optimize for inference (iterative decoding) when the query has a very small sequence
length (e.g., query sequence length = 1). The bottleneck here is to load the KV
cache as fast as possible, so we split the loading across different thread
blocks, with a separate kernel to combine the results.

See the function `flash_attn_with_kvcache`, which has more features for inference
(applying rotary embedding, updating the KV cache inplace).

Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.

### 2.3

Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window attention was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.

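To make the window semantics concrete, here is a small reference sketch (illustrative code, not the library's kernel) of the local-attention mask defined in the `flash_attn_func` docstring above:

```python
import torch

# Reference sliding-window mask: query i attends to keys j with
#   i + seqlen_k - seqlen_q - left <= j <= i + seqlen_k - seqlen_q + right,
# intersected with the causal mask when causal=True.
def sliding_window_mask(seqlen_q: int, seqlen_k: int, left: int, right: int) -> torch.Tensor:
    row = torch.arange(seqlen_q).unsqueeze(1) + (seqlen_k - seqlen_q)
    col = torch.arange(seqlen_k).unsqueeze(0)
    return ((col >= row - left) & (col <= row + right)).int()

print(sliding_window_mask(4, 4, left=1, right=0))
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]], dtype=torch.int32)
```
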
## Performance

We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).