Fix the case when dout is not contiguous
This commit is contained in:
parent
a1a5d2ee49
commit
88c4e5dbf6
@@ -38,6 +38,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
|
||||
as num_splits=3), so effectively the choices are 0, 1, and 2.
|
||||
This hyperparameter can be tuned for performance, but the default value (heuristic) should work fine.
|
||||
"""
|
||||
dout = dout.contiguous() # CUDA code assumes that dout is contiguous
|
||||
_, _, _, softmax_d = flash_attn_cuda.bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user