From 4360cfc6a850ee2431cc7b21fdba4fdc6bec4d0f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 22 Mar 2023 01:34:38 -0700 Subject: [PATCH] [Triton] Fix benchmark_causal.py --- benchmarks/benchmark_causal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index f290f53..8226c88 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -93,8 +93,8 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, desc='PyTorch Attention') -benchmark_all(flash_attn_qkvpacked_func, qkv, causal=causal, repeats=repeats, desc='FlashAttention Triton') -pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal=causal, backward=True) +benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton') +pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True) q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]