diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index f290f53..8226c88 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -93,8 +93,8 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, desc='PyTorch Attention') -benchmark_all(flash_attn_qkvpacked_func, qkv, causal=causal, repeats=repeats, desc='FlashAttention Triton') -pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal=causal, backward=True) +benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton') +pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True) q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]