diff --git a/README.md b/README.md
index 829230c..e4151a6 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,7 @@ To install:
 3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
 --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
 --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
-`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`
+`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
 compiling can take a very long time (2h) since it does not use multiple CPU
 cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
 4. Then:
@@ -202,10 +202,10 @@ pytest -q -s tests/test_flash_attn.py
 ```
 
 ## When you encounter issues
-This new release of FlashAttention-2 have been tested on several GPT-style
+This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.
 
-If you encounter any of bugs, please open a respective GitHub Issue!
+If you encounter bugs, please open a GitHub Issue!
 
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index c6e4340..19796a0 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -37,12 +37,8 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
 
 
 def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
-    if q.stride(-1) != 1:
-        q = q.contiguous()
-    if k.stride(-1) != 1:
-        k = k.contiguous()
-    if v.stride(-1) != 1:
-        v = v.contiguous()
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
         q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
     )
@@ -51,12 +47,8 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
 
 def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal, return_softmax):
-    if q.stride(-1) != 1:
-        q = q.contiguous()
-    if k.stride(-1) != 1:
-        k = k.contiguous()
-    if v.stride(-1) != 1:
-        v = v.contiguous()
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
         q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
         dropout_p, softmax_scale, False, causal, return_softmax, None
@@ -68,6 +60,9 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q
 
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                          dropout_p, softmax_scale, causal):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
     )
@@ -77,6 +72,9 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
 
 def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                 cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                 dropout_p, softmax_scale, causal):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
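For reference, the `flash_attn_interface.py` change above only centralizes the stride check: each input tensor is copied with `.contiguous()` only when its last (head) dimension is strided, which is what the CUDA kernels require. A minimal standalone sketch of the pattern, outside the patched file (the tensor shapes are illustrative, not taken from the patch):

```python
import torch

# Copy a tensor only if its last dimension is strided; otherwise return it unchanged.
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x

q = torch.randn(2, 128, 8, 64)                   # contiguous: returned as-is, no copy
k = torch.randn(2, 128, 8, 64).transpose(1, 2)   # non-contiguous overall, but stride(-1) == 1: no copy
v = torch.randn(2, 128, 8, 128)[..., ::2]        # stride(-1) == 2: .contiguous() makes a copy

q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
assert all(x.stride(-1) == 1 for x in (q, k, v))
```

In the backward functions, only `dout`, `q`, `k`, `v`, and `out` go through this helper; `dq`, `dk`, and `dv` are allocated internally and are therefore already contiguous, as the added comment in the diff notes.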