# flash-attention/benchmarks/benchmark_flash_attention.py

from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
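
# This script compares the unpadded (varlen) qkv-packed FlashAttention interface
# against a standard PyTorch attention reference on the same inputs.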


def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = (qkv.float() if upcast else qkv).unbind(dim=2)
    seqlen = qkv.shape[1]
    d = qkv.shape[-1]
    scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
    # Mask out attention to padded key positions.
    scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
    if causal:
        # Upper-triangular mask (excluding the diagonal) blocks attention to future positions.
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    # return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
    return output.to(dtype=qkv.dtype)
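

# Benchmark configuration: hidden dimension n = nheads * d, fp16 tensors on GPU.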
torch.manual_seed(0)
repeats = 30
batch_size = 64
nheads = 16
seqlen = 1024
n = 1024
d = n // nheads
dropout_p = 0.1
causal = False
dtype = torch.float16
device = 'cuda'
x = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
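
# Simulate variable-length sequences: each batch element gets a random valid length
# just below seqlen, and positions beyond it are treated as padding.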
lengths = torch.randint(seqlen - 20, seqlen, (batch_size, 1), device='cuda')
attention_mask_bool = repeat(torch.arange(seqlen, device='cuda'), 's -> b s', b=batch_size) < lengths
attention_mask = torch.zeros(batch_size, seqlen, device='cuda', dtype=dtype)
attention_mask[~attention_mask_bool] = -10000.0
attention_mask = rearrange(attention_mask, 'b s -> b 1 1 s')
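
# unpad_input drops the padding tokens, packing all valid tokens into one (nnz, hidden)
# tensor and returning the cumulative sequence lengths used by the varlen kernel.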
x_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, attention_mask_bool)
qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
                      h=nheads).detach().requires_grad_()
qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
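
# FlashAttention operates directly on the packed (unpadded) qkv tensor, using
# cu_seqlens / max_seqlen_in_batch to delimit the individual sequences.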
fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func(
    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
)
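
# Time both implementations on the same workload: FlashAttention on the packed input,
# the PyTorch reference on the padded (batch, seqlen) input.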
benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
benchmark_all(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')