diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 66325f4..1c732d6 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -43,7 +43,9 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
         return (128, 64) if is_sm80 else (64, 64)
 
 
-def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+def _flash_attn_forward(
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -202,7 +204,9 @@ def _flash_attn_varlen_backward(
 
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -322,7 +326,9 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -452,7 +458,9 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
 
 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -629,7 +637,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
-    alibi_slopes=None, 
+    alibi_slopes=None,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -1079,6 +1087,6 @@ def flash_attn_with_kvcache(
         window_size[1],
         rotary_interleaved,
         num_splits,
-        alibi_slopes
+        alibi_slopes,
     )
     return out