Make sure dout is contiguous
parent 4f285b3547
commit b4cc152e97
@@ -42,7 +42,7 @@ To install:
 3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
 --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
 --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
-`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`
+`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
 compiling can take a very long time (2h) since it does not use multiple CPU
 cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
 4. Then:
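For readers who script their environment setup, the check in step 3 is easy to automate. Below is a minimal sketch (the helper name and the use of `subprocess` are illustrative, not part of the repo) that reproduces the `ninja --version` exit-code test and the reinstall command quoted above:

```python
import subprocess
import sys


def ensure_ninja_works() -> None:
    """Mirror step 3 of the README: `ninja --version` must exit with code 0."""
    try:
        ok = subprocess.run(["ninja", "--version"], capture_output=True).returncode == 0
    except FileNotFoundError:
        ok = False
    if not ok:
        # Same fix the README suggests: uninstall, then reinstall ninja.
        subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "ninja"], check=True)
        subprocess.run([sys.executable, "-m", "pip", "install", "ninja"], check=True)


if __name__ == "__main__":
    ensure_ninja_works()
```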
@@ -202,10 +202,10 @@ pytest -q -s tests/test_flash_attn.py
 ```
 ## When you encounter issues
 
-This new release of FlashAttention-2 have been tested on several GPT-style
+This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.
 
-If you encounter any of bugs, please open a respective GitHub Issue!
+If you encounter bugs, please open a GitHub Issue!
 
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
@@ -37,12 +37,8 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
 
 
 def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
-    if q.stride(-1) != 1:
-        q = q.contiguous()
-    if k.stride(-1) != 1:
-        k = k.contiguous()
-    if v.stride(-1) != 1:
-        v = v.contiguous()
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
         q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
     )
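The new helper copies a tensor only when its last (head-dimension) stride is not 1, which is the layout the CUDA kernels expect; inputs that are already contiguous in that dimension pass through without a copy. A self-contained sketch of that behavior, with illustrative tensor shapes:

```python
import torch

# Same helper as introduced in this commit.
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x

# (batch, seqlen, nheads, headdim) layout: the last dim is already unit-stride,
# so the helper returns the tensor untouched (no copy).
q = torch.randn(2, 128, 8, 64)
assert q.stride(-1) == 1 and maybe_contiguous(q) is q

# A transposed view shares storage but is no longer unit-stride in its last
# dim; this is the case where a contiguous copy is actually made.
q_t = q.transpose(1, 3)
assert q_t.stride(-1) != 1
assert maybe_contiguous(q_t).stride(-1) == 1
```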
@@ -51,12 +47,8 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
 
 def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal, return_softmax):
-    if q.stride(-1) != 1:
-        q = q.contiguous()
-    if k.stride(-1) != 1:
-        k = k.contiguous()
-    if v.stride(-1) != 1:
-        v = v.contiguous()
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
         q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
         softmax_scale, False, causal, return_softmax, None
@@ -68,6 +60,9 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q
 
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                          dropout_p, softmax_scale, causal):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
     )
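This hunk is what the commit title refers to: `dout` is produced by whatever operation consumes the attention output, so autograd may deliver it as a strided, non-contiguous tensor, whereas `dq`, `dk`, `dv` are allocated by the library and are already contiguous. A small sketch of how that situation can arise, using a toy `torch.autograd.Function` as a stand-in for the real forward/backward pair (the toy class is illustrative, not part of the library):

```python
import torch

# Same helper as in the diff: copy only when the last dim is not unit-stride.
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x


class ToyAttn(torch.autograd.Function):
    """Toy stand-in for the real attention Function; it only inspects dout."""

    @staticmethod
    def forward(ctx, q):
        return q.clone()

    @staticmethod
    def backward(ctx, dout):
        # dout's layout is decided by the op downstream of the output,
        # so it can arrive with dout.stride(-1) != 1.
        print("incoming dout.stride(-1) =", dout.stride(-1))
        dout = maybe_contiguous(dout)  # the same normalization the commit adds
        assert dout.stride(-1) == 1
        return dout


q = torch.randn(2, 128, 8, 64, requires_grad=True)
out = ToyAttn.apply(q)
# A transpose downstream of the output makes the gradient reaching backward()
# a transposed view, i.e. non-contiguous in its last dimension.
out.transpose(1, 3).sum().backward()
```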
@@ -77,6 +72,9 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
 def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                 cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                 dropout_p, softmax_scale, causal):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None