Fix QKV interface to allocate output in Python

This commit is contained in:
Tri Dao 2022-10-14 03:33:41 -07:00
parent 5badfb7848
commit 1b9facacc3

View File

@ -47,8 +47,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax
)
ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p