diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 8c22158..995bea1 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
+                return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
-            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None
 
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         if dropout_p > 0:
             rng_state0 = torch.cuda.get_rng_state()
@@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
         ctx.batch_size0 = batch_size0
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         if not return_softmax:
             return out
         else:
@@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
             cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
-            ctx.softmax_scale, ctx.causal
+            ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
         )
         s = torch.cuda.Stream()
         with torch.cuda.stream(s):
@@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
                 dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                 dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
                 cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
-                ctx.softmax_scale, ctx.causal, generator=generator1
+                ctx.softmax_scale, ctx.causal, generator=generator1,
+                num_splits=1 if ctx.deterministic else 0,
             )
         torch.cuda.current_stream().wait_stream(s)
         if rng_state0 is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
-                                       causal=False, return_attn_probs=False):
+                                       causal=False, return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
@@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
-                                        causal, return_attn_probs)
+                                        causal, return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                       dropout_p, softmax_scale=None, causal=False,
-                                      return_attn_probs=False):
+                                      return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
    """
    return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
                                        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
-                                       return_attn_probs)
+                                       return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
+                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
+                             deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                               dropout_p, softmax_scale, causal, return_attn_probs)
+                               dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_qkvpacked_split_func(
     qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
-    causal=False, return_attn_probs=False):
+    causal=False, return_attn_probs=False, deterministic=False):
     """
     Split attention into 2 kernels running on 2 separate streams for performance reason: e.g., if
     the batch has some sequences of length <= 128 and some > 128, it might be faster to
@@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
-                                             dropout_p, softmax_scale, causal, return_attn_probs)
+                                             dropout_p, softmax_scale, causal, return_attn_probs,
+                                             deterministic)
 
 
 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
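
Usage sketch (not part of the patch): the new deterministic flag only changes what the autograd backward passes to _flash_attn_backward (num_splits=1 instead of the heuristic 0), so callers just pass it through the public functions. The snippet below is a minimal illustration, assuming flash_attn is installed and a CUDA GPU with fp16 support is available; the shapes are made up for the example.

    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

    batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
    total = batch_size * seqlen
    # Packed QKV: (total, 3, nheads, headdim), fp16 on the GPU.
    qkv = torch.randn(total, 3, nheads, headdim, device='cuda', dtype=torch.float16,
                      requires_grad=True)
    # Cumulative sequence lengths: (batch_size + 1,), int32, here [0, 128, 256].
    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                              device='cuda', dtype=torch.int32)
    # deterministic=True makes the backward call _flash_attn_backward with
    # num_splits=1 rather than the heuristic 0, as in the patch above.
    out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, seqlen, dropout_p=0.1,
                                             causal=True, deterministic=True)
    # With the same RNG seed, gradients should now be reproducible run to run,
    # at some cost in backward-pass speed.
    out.sum().backward()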