Make sure dout is contiguous

This commit is contained in:
Tri Dao 2023-07-17 21:54:44 -07:00
parent 4f285b3547
commit b4cc152e97
2 changed files with 13 additions and 15 deletions

View File

@ -42,7 +42,7 @@ To install:
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja 3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja` `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
compiling can take a very long time (2h) since it does not use multiple CPU compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine. cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
4. Then: 4. Then:
@ -202,10 +202,10 @@ pytest -q -s tests/test_flash_attn.py
``` ```
## When you encounter issues ## When you encounter issues
This new release of FlashAttention-2 have been tested on several GPT-style This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs. models, mostly on A100 GPUs.
If you encounter any of bugs, please open a respective GitHub Issue! If you encounter bugs, please open a GitHub Issue!
## Citation ## Citation
If you use this codebase, or otherwise found our work valuable, please cite: If you use this codebase, or otherwise found our work valuable, please cite:

View File

@ -37,12 +37,8 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
if q.stride(-1) != 1: maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q = q.contiguous() q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if k.stride(-1) != 1:
k = k.contiguous()
if v.stride(-1) != 1:
v = v.contiguous()
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd( out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
) )
@ -51,12 +47,8 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_softmax): dropout_p, softmax_scale, causal, return_softmax):
if q.stride(-1) != 1: maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q = q.contiguous() q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if k.stride(-1) != 1:
k = k.contiguous()
if v.stride(-1) != 1:
v = v.contiguous()
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd( out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, False, causal, return_softmax, None softmax_scale, False, causal, return_softmax, None
@ -68,6 +60,9 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
dropout_p, softmax_scale, causal): dropout_p, softmax_scale, causal):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
) )
@ -77,6 +72,9 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal): dropout_p, softmax_scale, causal):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None