Add option for deterministic execution

Kirthi Shankar Sivamani 2023-03-30 18:23:35 -07:00
parent 009a3e71ec
commit b6aa059bbf
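
A short usage sketch of the new flag, for context while reviewing (not part of the commit): it assumes the wrappers below are importable from flash_attn.flash_attn_interface (the file path is not shown on this page) and uses made-up shapes and dtypes. Passing deterministic=True is stored on the autograd context and, in backward(), turns into num_splits=1 in the _flash_attn_backward call (0 lets the kernel pick the split count heuristically), so the backward pass executes deterministically.

    import torch
    # Assumed import path; this commit page does not show the file name.
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

    # Illustrative data: two packed sequences of lengths 3 and 5, 4 heads, head dim 64.
    total, nheads, headdim = 8, 4, 64
    qkv = torch.randn(total, 3, nheads, headdim, device='cuda', dtype=torch.float16,
                      requires_grad=True)
    cu_seqlens = torch.tensor([0, 3, 8], device='cuda', dtype=torch.int32)
    max_seqlen = 5

    # deterministic=True is forwarded to FlashAttnQKVPackedFunc and becomes
    # num_splits=1 in the backward kernel call (0 = heuristic split choice).
    out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0,
                                             causal=True, deterministic=True)
    out.sum().backward()  # gradients computed with a fixed, deterministic split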


@@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
 class FlashAttnQKVPackedFunc(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
+                return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
-            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None


 class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         if dropout_p > 0:
             rng_state0 = torch.cuda.get_rng_state()
@@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
         ctx.batch_size0 = batch_size0
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         if not return_softmax:
             return out
         else:
@@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
             cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
-            ctx.softmax_scale, ctx.causal
+            ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
         )
         s = torch.cuda.Stream()
         with torch.cuda.stream(s):
@@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
                 dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                 dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
                 cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
-                ctx.softmax_scale, ctx.causal, generator=generator1
+                ctx.softmax_scale, ctx.causal, generator=generator1,
+                num_splits=1 if ctx.deterministic else 0,
             )
         torch.cuda.current_stream().wait_stream(s)
         if rng_state0 is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None


 def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
-                                       causal=False, return_attn_probs=False):
+                                       causal=False, return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
@@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
-                                        causal, return_attn_probs)
+                                        causal, return_attn_probs, deterministic)


 def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                       dropout_p, softmax_scale=None, causal=False,
-                                      return_attn_probs=False):
+                                      return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
     """
     return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
                                        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
-                                       return_attn_probs)
+                                       return_attn_probs, deterministic)


 def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
+                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
+                             deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                               dropout_p, softmax_scale, causal, return_attn_probs)
+                               dropout_p, softmax_scale, causal, return_attn_probs, deterministic)


 def flash_attn_unpadded_qkvpacked_split_func(
         qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
-        causal=False, return_attn_probs=False):
+        causal=False, return_attn_probs=False, deterministic=False):
     """
     Split attention into 2 kernels running on 2 separate streams for performance reason:
     e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
@@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
-                                             dropout_p, softmax_scale, causal, return_attn_probs)
+                                             dropout_p, softmax_scale, causal, return_attn_probs,
+                                             deterministic)


 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,