Format flash_attn_interface.py
parent 0a146185d6
commit bc28eacc60
@@ -43,7 +43,9 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
         return (128, 64) if is_sm80 else (64, 64)
 
 
-def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+def _flash_attn_forward(
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -202,7 +204,9 @@ def _flash_attn_varlen_backward(
 
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -322,7 +326,9 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -452,7 +458,9 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
 
 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -629,7 +637,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     alibi_slopes=None,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
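For orientation, the keyword defaults shown in the hunk above belong to the public flash_attn_kvpacked_func wrapper. A minimal usage sketch follows; the tensor shapes and the flash_attn import path are assumptions based on the package's public API, not anything shown in this diff:

import torch
from flash_attn import flash_attn_kvpacked_func  # assumed public import path

# Assumed shapes: q is (batch, seqlen_q, nheads, headdim),
# kv packs K and V as (batch, seqlen_k, 2, nheads, headdim).
q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
kv = torch.randn(2, 1024, 2, 8, 64, dtype=torch.float16, device="cuda")

out = flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,         # per the docstring above: 0.0 during evaluation
    softmax_scale=None,    # None -> q.shape[-1] ** (-0.5), as in the forward hunks
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    return_attn_probs=False,
)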
@@ -1079,6 +1087,6 @@ def flash_attn_with_kvcache(
         window_size[1],
         rotary_interleaved,
         num_splits,
-        alibi_slopes
+        alibi_slopes,
     )
     return out
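The trailing-comma fix above sits in the tail of the call that flash_attn_with_kvcache makes into the CUDA extension. As a rough sketch of how that wrapper is typically used for single-token decoding (argument names, shapes, and defaults here are assumptions drawn from the package's public API, not from this diff):

import torch
from flash_attn import flash_attn_with_kvcache  # assumed public import path

batch, nheads, headdim, cache_len = 2, 8, 64, 4096

# Pre-allocated KV cache plus one new token per sequence; cache_seqlens tracks
# how many positions of each cache row are already filled.
k_cache = torch.zeros(batch, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")

q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=k_new,  # new keys/values are appended into the cache in place
    v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
)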