diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py
index b3713e0..f290f53 100644
--- a/benchmarks/benchmark_causal.py
+++ b/benchmarks/benchmark_causal.py
@@ -93,8 +93,8 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
 benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, desc='PyTorch Attention')
 
-benchmark_all(flash_attn_qkvpacked_func, qkv, causal, repeats=repeats, desc='FlashAttention Triton')
-pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal, backward=True)
+benchmark_all(flash_attn_qkvpacked_func, qkv, causal=causal, repeats=repeats, desc='FlashAttention Triton')
+pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal=causal, backward=True)
 
 q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                        requires_grad=True) for _ in range(3)]
diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index ebc1bf8..78b7588 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -1,5 +1,10 @@
 """
 *Experimental* implementation of FlashAttention in Triton.
+Tested with triton==2.0.0.dev20221202.
+Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
+other than 64:
+https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
+We'll update this implementation with the new Triton backend once this is fixed.
 
 We use the FlashAttention implementation from Phil Tillet a starting point.
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
@@ -773,7 +778,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, kv, o, lse, bias = ctx.saved_tensors
-        assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
+        if len(ctx.needs_input_grad) >= 3:
+            assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
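
Note on the FlashAttnKVPackedFunc.backward change above: for a torch.autograd.Function, ctx.needs_input_grad has one entry per argument actually passed to .apply(), so unconditionally indexing position 2 can raise an IndexError when the optional bias argument is omitted. The sketch below uses a hypothetical ToyFunc (not the FlashAttention code) to illustrate the behavior the length guard accounts for; it also relies on the fact that autograd drops extra trailing None gradients returned from backward, which is why backward can always return a gradient slot for bias.

import torch

# Minimal sketch (hypothetical ToyFunc, not FlashAttention): needs_input_grad
# has one entry per argument passed to .apply(), so indexing a fixed position
# needs a length check when trailing arguments are optional.
class ToyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, bias=None):
        return q + kv if bias is None else q + kv + bias

    @staticmethod
    def backward(ctx, grad_out):
        # Guard the index: ToyFunc.apply(q, kv) gives len(ctx.needs_input_grad) == 2.
        if len(ctx.needs_input_grad) >= 3:
            assert not ctx.needs_input_grad[2], 'bias gradient not supported'
        # Extra trailing None gradients are ignored by autograd if fewer inputs were passed.
        return grad_out, grad_out, None

q = torch.randn(4, requires_grad=True)
kv = torch.randn(4, requires_grad=True)
ToyFunc.apply(q, kv).sum().backward()                   # needs_input_grad has 2 entries
ToyFunc.apply(q, kv, torch.randn(4)).sum().backward()   # needs_input_grad has 3 entries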