Fix QKV interface to allocate output in Python
parent 5badfb7848
commit 1b9facacc3
@@ -47,8 +47,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, softmax_lse, S_dmask = _flash_attn_forward(
-            qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
-            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
+            qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
+            max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
+            return_softmax=return_softmax
         )
         ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
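For context, a minimal runnable sketch of the pattern this commit adopts: the output buffer is allocated in Python with torch.empty_like and handed to the forward call, rather than being allocated inside the CUDA extension. attn_forward below is a hypothetical stand-in for _flash_attn_forward, with simplified shapes and none of the cu_seqlens/dropout/causal machinery.

    import torch

    def attn_forward(q, k, v, out, softmax_scale):
        # Hypothetical stand-in for the _flash_attn_forward binding: it writes
        # the attention result into the caller-provided `out` buffer in place
        # instead of allocating a new tensor itself.
        scores = torch.softmax(q @ k.transpose(-2, -1) * softmax_scale, dim=-1)
        out.copy_(scores @ v)
        return out

    # Simplified packed layout (batch, 3, seqlen, headdim); the real interface
    # uses (total_tokens, 3, nheads, headdim) plus cu_seqlens offsets.
    qkv = torch.randn(2, 3, 16, 64)
    softmax_scale = qkv.shape[-1] ** (-0.5)

    # Output allocated in Python, mirroring torch.empty_like(qkv[:, 0]) above.
    out = torch.empty_like(qkv[:, 0])
    attn_forward(qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_scale)

A plausible motivation for this shape of interface is that allocating in the Python wrapper gives the autograd function control over the output tensor's dtype, device, and layout, and keeps allocator calls out of the extension itself.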