Make sure dout is contiguous
This commit is contained in:
parent
4f285b3547
commit
b4cc152e97
@ -42,7 +42,7 @@ To install:
|
||||
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
|
||||
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
|
||||
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
|
||||
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`
|
||||
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
|
||||
compiling can take a very long time (2h) since it does not use multiple CPU
|
||||
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
|
||||
4. Then:
|
||||
@ -202,10 +202,10 @@ pytest -q -s tests/test_flash_attn.py
|
||||
```
|
||||
## When you encounter issues
|
||||
|
||||
This new release of FlashAttention-2 have been tested on several GPT-style
|
||||
This new release of FlashAttention-2 has been tested on several GPT-style
|
||||
models, mostly on A100 GPUs.
|
||||
|
||||
If you encounter any of bugs, please open a respective GitHub Issue!
|
||||
If you encounter bugs, please open a GitHub Issue!
|
||||
|
||||
## Citation
|
||||
If you use this codebase, or otherwise found our work valuable, please cite:
|
||||
|
||||
@ -37,12 +37,8 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
|
||||
|
||||
|
||||
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
|
||||
if q.stride(-1) != 1:
|
||||
q = q.contiguous()
|
||||
if k.stride(-1) != 1:
|
||||
k = k.contiguous()
|
||||
if v.stride(-1) != 1:
|
||||
v = v.contiguous()
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
|
||||
q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
|
||||
)
|
||||
@ -51,12 +47,8 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
|
||||
|
||||
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_softmax):
|
||||
if q.stride(-1) != 1:
|
||||
q = q.contiguous()
|
||||
if k.stride(-1) != 1:
|
||||
k = k.contiguous()
|
||||
if v.stride(-1) != 1:
|
||||
v = v.contiguous()
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
|
||||
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, False, causal, return_softmax, None
|
||||
@ -68,6 +60,9 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q
|
||||
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
dropout_p, softmax_scale, causal):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
# dq, dk, dv are allocated by us so they should already be contiguous
|
||||
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
|
||||
)
|
||||
@ -77,6 +72,9 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
# dq, dk, dv are allocated by us so they should already be contiguous
|
||||
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user