Fix the case when dout is not contiguous

Tri Dao 2022-12-13 13:58:17 -08:00
parent a1a5d2ee49
commit 88c4e5dbf6

@@ -38,6 +38,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
     as num_splits=3), so effectively the choices are 0, 1, and 2.
     This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
     """
+    dout = dout.contiguous()  # CUDA code assumes that dout is contiguous
     _, _, _, softmax_d = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
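
The sketch below illustrates why the added `dout.contiguous()` call matters. It is a pure-Python toy model, not the flash_attn or PyTorch implementation: `StridedView` and `kernel_row_sums` are hypothetical names invented for this example. It assumes a kernel that ignores strides and reads the flat buffer as if it were row-major, which is roughly the situation the commit comment describes for the CUDA code. A transposed view shares storage with its source but has swapped strides, so such a kernel reads the wrong elements until the data is copied into contiguous order.

```python
# Illustrative only: a tiny pure-Python model of strided 2-D tensors.
# StridedView and kernel_row_sums are hypothetical names, not part of
# flash_attn or PyTorch.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class StridedView:
    data: List[float]          # flat storage, possibly shared between views
    shape: Tuple[int, int]     # (rows, cols)
    strides: Tuple[int, int]   # step in `data` per row / per column

    def get(self, i: int, j: int) -> float:
        # Stride-aware element access.
        return self.data[i * self.strides[0] + j * self.strides[1]]

    def is_contiguous(self) -> bool:
        # Row-major contiguous means strides == (cols, 1).
        return self.strides == (self.shape[1], 1)

    def transpose(self) -> "StridedView":
        # Swap shape and strides without copying the storage.
        return StridedView(self.data, self.shape[::-1], self.strides[::-1])

    def contiguous(self) -> "StridedView":
        # Return self if already row-major, else copy into row-major order.
        if self.is_contiguous():
            return self
        rows, cols = self.shape
        flat = [self.get(i, j) for i in range(rows) for j in range(cols)]
        return StridedView(flat, self.shape, (cols, 1))


def kernel_row_sums(x: StridedView) -> List[float]:
    """Stand-in for a kernel that ignores strides and assumes row-major layout."""
    rows, cols = x.shape
    return [sum(x.data[i * cols:(i + 1) * cols]) for i in range(rows)]


base = StridedView([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), (3, 1))
dout = base.transpose()  # shape (3, 2), strides (1, 3): not contiguous

# Reference result computed with stride-aware indexing.
expected = [sum(dout.get(i, j) for j in range(2)) for i in range(3)]
wrong = kernel_row_sums(dout)               # reads storage in the wrong order
fixed = kernel_row_sums(dout.contiguous())  # copy first, then run the kernel

print(expected, wrong, fixed)
```

In this toy model `wrong` differs from `expected`, while `fixed` matches it. Because `contiguous()` returns the same object when the layout is already row-major, the extra call costs a copy only in the non-contiguous case.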