[Triton] Fix benchmark_causal, mention Triton version
commit 5d079fdd7a
parent dc08ea1c33
@@ -93,8 +93,8 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
 benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
               repeats=repeats, desc='PyTorch Attention')

-benchmark_all(flash_attn_qkvpacked_func, qkv, causal, repeats=repeats, desc='FlashAttention Triton')
-pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal, backward=True)
+benchmark_all(flash_attn_qkvpacked_func, qkv, causal=causal, repeats=repeats, desc='FlashAttention Triton')
+pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal=causal, backward=True)

 q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                        requires_grad=True) for _ in range(3)]
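The benchmark fix is purely positional-to-keyword: the Triton benchmark and profiler calls now pass causal=causal. Assuming the wrapper's signature is roughly flash_attn_qkvpacked_func(qkv, bias=None, causal=False, softmax_scale=None) (inferred, not shown in this diff), a bare positional causal would have been bound to bias and the causal flag left at its default. A minimal sketch of that mix-up, with a stand-in function in place of the real kernel:

# Stand-in that mimics the assumed wrapper signature; not the real Triton kernel.
def flash_attn_qkvpacked_func(qkv, bias=None, causal=False, softmax_scale=None):
    return bias, causal

qkv = object()   # placeholder for the packed (batch, seqlen, 3, nheads, headdim) tensor
causal = True

bias, is_causal = flash_attn_qkvpacked_func(qkv, causal)
assert bias is True and is_causal is False    # positional: the flag lands in `bias`

bias, is_causal = flash_attn_qkvpacked_func(qkv, causal=causal)
assert bias is None and is_causal is True     # keyword form from this commit: intended behavior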
@@ -1,5 +1,10 @@
 """
 *Experimental* implementation of FlashAttention in Triton.
+Tested with triton==2.0.0.dev20221202.
+Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
+other than 64:
+https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
+We'll update this implementation with the new Triton backend once this is fixed.

 We use the FlashAttention implementation from Phil Tillet as a starting point.
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
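The docstring addition pins the exact Triton build this kernel was tested against. If one wanted to surface that pin at import time, a hypothetical guard could look like the sketch below; neither the constant nor the warning exists in the repository.

# Illustrative sketch: warn when the installed Triton differs from the build the
# docstring says this kernel was tested with. Not part of this commit.
import warnings

import triton

TESTED_TRITON_VERSION = "2.0.0.dev20221202"   # hypothetical constant

if triton.__version__ != TESTED_TRITON_VERSION:
    warnings.warn(
        f"flash_attn_triton was tested with triton=={TESTED_TRITON_VERSION}, "
        f"found triton=={triton.__version__}; kernels may fail to compile or give wrong results."
    )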
@@ -773,7 +778,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
        q, kv, o, lse, bias = ctx.saved_tensors
-        assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
+        if len(ctx.needs_input_grad) >= 3:
+            assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
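The functional change in this hunk is only the guard around the bias assertion; the inference_mode block is unchanged context. ctx.needs_input_grad carries one boolean per argument actually passed to Function.apply, so an unconditional needs_input_grad[2] can raise IndexError when the optional arguments were left at their defaults. The commit does not state its motivation, but that is the failure mode the guard tolerates. A toy autograd.Function showing the same pattern (hypothetical, not the FlashAttention code):

import torch

class AddBias(torch.autograd.Function):
    """Toy op: x + bias, with no gradient support for bias."""

    @staticmethod
    def forward(ctx, x, bias=None):
        return x if bias is None else x + bias

    @staticmethod
    def backward(ctx, grad_out):
        # needs_input_grad mirrors the args given to apply(); AddBias.apply(x)
        # yields a length-1 tuple, so indexing [1] unconditionally would raise IndexError.
        if len(ctx.needs_input_grad) >= 2:
            assert not ctx.needs_input_grad[1], 'bias gradient not supported'
        return grad_out, None   # trailing None gradients for unpassed inputs are allowed

x = torch.randn(4, requires_grad=True)
AddBias.apply(x).sum().backward()                    # one arg passed: guarded assert is skipped
AddBias.apply(x, torch.zeros(4)).sum().backward()    # two args: assert runs and passes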