Add option for deterministic execution

This commit is contained in:
Kirthi Shankar Sivamani 2023-03-30 18:23:35 -07:00
parent 009a3e71ec
commit b6aa059bbf

View File

@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
return_softmax, deterministic):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
_flash_attn_backward(
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
num_splits=1 if ctx.deterministic else 0,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax):
softmax_scale, causal, return_softmax, deterministic):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
_flash_attn_backward(
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
num_splits=1 if ctx.deterministic else 0,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax):
softmax_scale, causal, return_softmax, deterministic):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
num_splits=1 if ctx.deterministic else 0,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dk, dv, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None
class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
softmax_scale, causal, return_softmax):
softmax_scale, causal, return_softmax, deterministic):
# Save rng_state because the backward pass will regenerate the dropout mask
if dropout_p > 0:
rng_state0 = torch.cuda.get_rng_state()
@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
ctx.batch_size0 = batch_size0
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
if not return_softmax:
return out
else:
@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
ctx.softmax_scale, ctx.causal
ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
)
s = torch.cuda.Stream()
with torch.cuda.stream(s):
@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
ctx.softmax_scale, ctx.causal, generator=generator1
ctx.softmax_scale, ctx.causal, generator=generator1,
num_splits=1 if ctx.deterministic else 0,
)
torch.cuda.current_stream().wait_stream(s)
if rng_state0 is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None, None, None
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False):
causal=False, return_attn_probs=False, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
causal, return_attn_probs)
causal, return_attn_probs, deterministic)
def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale=None, causal=False,
return_attn_probs=False):
return_attn_probs=False, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
"""
return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
return_attn_probs)
return_attn_probs, deterministic)
def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_attn_probs)
dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
def flash_attn_unpadded_qkvpacked_split_func(
qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False):
causal=False, return_attn_probs=False, deterministic=False):
"""
Split attention into 2 kernels running on 2 separate streams for performance reason:
e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
deterministic: bool. Whether or not to ensure deterministic execution.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
dropout_p, softmax_scale, causal, return_attn_probs)
dropout_p, softmax_scale, causal, return_attn_probs,
deterministic)
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,