Add option for deterministic execution
This commit is contained in:
parent
009a3e71ec
commit
b6aa059bbf
@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
|
||||
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
|
||||
return_softmax, deterministic):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
ctx.max_seqlen = max_seqlen
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
_flash_attn_backward(
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
|
||||
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None, None, None
|
||||
return dqkv, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
_flash_attn_backward(
|
||||
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
|
||||
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dkv, None, None, None, None, None, None, None, None
|
||||
return dq, dkv, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
if dropout_p > 0:
|
||||
rng_state0 = torch.cuda.get_rng_state()
|
||||
@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
|
||||
ctx.batch_size0 = batch_size0
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
if not return_softmax:
|
||||
return out
|
||||
else:
|
||||
@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
|
||||
cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
|
||||
ctx.softmax_scale, ctx.causal
|
||||
ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
|
||||
cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
|
||||
ctx.softmax_scale, ctx.causal, generator=generator1
|
||||
ctx.softmax_scale, ctx.causal, generator=generator1,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
if rng_state0 is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None, None, None, None, None
|
||||
return dqkv, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
|
||||
causal=False, return_attn_probs=False):
|
||||
causal=False, return_attn_probs=False, deterministic=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Arguments:
|
||||
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
|
||||
@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
|
||||
causal, return_attn_probs)
|
||||
causal, return_attn_probs, deterministic)
|
||||
|
||||
|
||||
def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
return_attn_probs=False, deterministic=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
|
||||
"""
|
||||
return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
|
||||
return_attn_probs)
|
||||
return_attn_probs, deterministic)
|
||||
|
||||
|
||||
def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
|
||||
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
|
||||
deterministic=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
|
||||
|
||||
|
||||
def flash_attn_unpadded_qkvpacked_split_func(
|
||||
qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
|
||||
causal=False, return_attn_probs=False):
|
||||
causal=False, return_attn_probs=False, deterministic=False):
|
||||
"""
|
||||
Split attention into 2 kernels running on 2 separate streams for performance reason:
|
||||
e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
|
||||
@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
dropout_p, softmax_scale, causal, return_attn_probs,
|
||||
deterministic)
|
||||
|
||||
|
||||
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user