Make sure dout is contiguous

Tri Dao 2023-07-17 21:54:44 -07:00
parent 4f285b3547
commit b4cc152e97
2 changed files with 13 additions and 15 deletions
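The point of the change: the gradient of the attention output (`dout`) arrives from autograd and is not guaranteed to have a unit-stride last (head) dimension, so the Python wrappers below now run it through the same contiguity check as `q`, `k`, and `v` before calling into the CUDA extension. A minimal standalone sketch of the failure mode (not part of this commit; the `Probe` helper and the shapes are made up for illustration):

```python
# Sketch only: shows how a non-contiguous gradient can reach a custom
# backward, which is what the contiguity check in the diff below guards against.
import torch


class Probe(torch.autograd.Function):
    """Identity op that reports the layout of the gradient it receives."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, dout):
        # dout is whatever upstream ops produced; autograd does not
        # guarantee that it is contiguous.
        print("stride(-1) =", dout.stride(-1), "contiguous =", dout.is_contiguous())
        return dout


x = torch.randn(2, 128, 8, 64, requires_grad=True)  # (batch, seqlen, nheads, headdim)
y = Probe.apply(x)

# A downstream transpose of the last two dims means the gradient handed to
# Probe.backward is a strided view: stride(-1) != 1, contiguous = False.
(y.transpose(-1, -2) * torch.randn(2, 128, 64, 8)).sum().backward()
```

This is why `dout` gets the same treatment as `q`, `k`, and `v` in the backward wrappers changed below.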


@@ -42,7 +42,7 @@ To install:
 3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
 --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
 --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
-`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`
+`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
 compiling can take a very long time (2h) since it does not use multiple CPU
 cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
 4. Then:
@@ -202,10 +202,10 @@ pytest -q -s tests/test_flash_attn.py
 ```
 ## When you encounter issues
-This new release of FlashAttention-2 have been tested on several GPT-style
+This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.
-If you encounter any of bugs, please open a respective GitHub Issue!
+If you encounter bugs, please open a GitHub Issue!
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:


@@ -37,12 +37,8 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
-    if q.stride(-1) != 1:
-        q = q.contiguous()
-    if k.stride(-1) != 1:
-        k = k.contiguous()
-    if v.stride(-1) != 1:
-        v = v.contiguous()
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
         q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
     )
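A quick standalone check of the `maybe_contiguous` helper introduced above (a sketch; the shapes are arbitrary): already-contiguous tensors are returned as-is, so the common case costs nothing, while tensors with a strided last dimension are copied once.

```python
import torch

# Same one-liner as in the diff above.
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x

q = torch.randn(2, 128, 8, 64)                    # contiguous: stride(-1) == 1
k = torch.randn(2, 128, 64, 8).transpose(-1, -2)  # strided view: stride(-1) != 1

q2, k2 = [maybe_contiguous(x) for x in (q, k)]

print(q2 is q)        # True  -> no copy on the fast path
print(k2 is k)        # False -> copied so the kernel sees stride(-1) == 1
print(k2.stride(-1))  # 1
```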
@@ -51,12 +47,8 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
 def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal, return_softmax):
-    if q.stride(-1) != 1:
-        q = q.contiguous()
-    if k.stride(-1) != 1:
-        k = k.contiguous()
-    if v.stride(-1) != 1:
-        v = v.contiguous()
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
         q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
         softmax_scale, False, causal, return_softmax, None
@@ -68,6 +60,9 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                          dropout_p, softmax_scale, causal):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
     )
@@ -77,6 +72,9 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
 def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                 cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                 dropout_p, softmax_scale, causal):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
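On the `# dq, dk, dv are allocated by us` comment: the gradient buffers are created inside the wrapper, so they are already unit-stride and re-checking them would be redundant; only `dout`, `q`, `k`, `v`, and `out`, which arrive from outside the wrapper, need the `maybe_contiguous` treatment. A quick check (sketch only; the `empty_like` allocation and shapes are assumptions for illustration, not taken from this diff):

```python
import torch

q = torch.randn(2, 128, 8, 64)  # contiguous (batch, seqlen, nheads, headdim)

# Hypothetical allocation of a gradient buffer by the wrapper;
# empty_like on a contiguous tensor yields a contiguous buffer.
dq = torch.empty_like(q)

print(dq.is_contiguous(), dq.stride(-1))  # True 1
```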