diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index df914dd..8c22158 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -38,6 +38,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens as num_splits=3), so effectively the choices are 0, 1, and 2. This hyperparameter can be tuned for performance, but default value (heuristic) should work fine. """ + dout = dout.contiguous() # CUDA code assumes that dout is contiguous _, _, _, softmax_d = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)