diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 8e912fe..a6d8b4f 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -47,8 +47,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, softmax_lse, S_dmask = _flash_attn_forward(
-            qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
-            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
+            qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
+            max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
+            return_softmax=return_softmax
         )
         ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p