[Triton] Fix benchmark_causal.py
This commit is contained in:
parent
5d079fdd7a
commit
4360cfc6a8
@ -93,8 +93,8 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
|
||||
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
|
||||
repeats=repeats, desc='PyTorch Attention')
|
||||
|
||||
benchmark_all(flash_attn_qkvpacked_func, qkv, causal=causal, repeats=repeats, desc='FlashAttention Triton')
|
||||
pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal=causal, backward=True)
|
||||
benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
|
||||
pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
|
||||
|
||||
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True) for _ in range(3)]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user