Add option for deterministic execution
commit b6aa059bbf
parent 009a3e71ec
@@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
 class FlashAttnQKVPackedFunc(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
+                return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
-            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None


 class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         if dropout_p > 0:
             rng_state0 = torch.cuda.get_rng_state()
@@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
         ctx.batch_size0 = batch_size0
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         if not return_softmax:
             return out
         else:
@@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
             cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
-            ctx.softmax_scale, ctx.causal
+            ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
         )
         s = torch.cuda.Stream()
         with torch.cuda.stream(s):
@@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
                 dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                 dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
                 cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
-                ctx.softmax_scale, ctx.causal, generator=generator1
+                ctx.softmax_scale, ctx.causal, generator=generator1,
+                num_splits=1 if ctx.deterministic else 0,
             )
         torch.cuda.current_stream().wait_stream(s)
         if rng_state0 is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None


 def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
-                                       causal=False, return_attn_probs=False):
+                                       causal=False, return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
@@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
-                                        causal, return_attn_probs)
+                                        causal, return_attn_probs, deterministic)


 def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                       dropout_p, softmax_scale=None, causal=False,
-                                      return_attn_probs=False):
+                                      return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
     """
     return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
                                        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
-                                       return_attn_probs)
+                                       return_attn_probs, deterministic)


 def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
+                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
+                             deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                               dropout_p, softmax_scale, causal, return_attn_probs)
+                               dropout_p, softmax_scale, causal, return_attn_probs, deterministic)


 def flash_attn_unpadded_qkvpacked_split_func(
     qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
-    causal=False, return_attn_probs=False):
+    causal=False, return_attn_probs=False, deterministic=False):
     """
     Split attention into 2 kernels running on 2 separate streams for performance reason:
     e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
@@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
-                                             dropout_p, softmax_scale, causal, return_attn_probs)
+                                             dropout_p, softmax_scale, causal, return_attn_probs,
+                                             deterministic)


 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
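For context, here is a minimal usage sketch of the new flag (not part of the commit itself). It assumes the patched module is importable as flash_attn.flash_attn_interface, that a CUDA GPU with fp16 support is available, and that the shapes below are purely illustrative; as the hunks above show, deterministic=True simply routes num_splits=1 into the backward call, which is intended to make gradients reproducible at some potential cost in speed.

```python
# Usage sketch (assumption: the patched flash_attn.flash_attn_interface is installed
# and a CUDA device is available; shapes here are illustrative only).
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
total = batch_size * seqlen

# Packed QKV for all tokens in the unpadded batch: (total, 3, nheads, headdim)
qkv = torch.randn(total, 3, nheads, headdim, device='cuda', dtype=torch.float16,
                  requires_grad=True)
# Cumulative sequence lengths: [0, seqlen, 2 * seqlen], int32 as expected by the kernel
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen,
                          device='cuda', dtype=torch.int32)

# deterministic=True is the new option: the backward pass is run with num_splits=1.
out = flash_attn_unpadded_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen=seqlen, dropout_p=0.1,
    causal=True, deterministic=True,
)
out.sum().backward()  # gradients land in qkv.grad
```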