diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py
index e0b6388..4b39540 100644
--- a/benchmarks/benchmark_flash_attention.py
+++ b/benchmarks/benchmark_flash_attention.py
@@ -1,4 +1,6 @@
-from functools import partial
+# Install the newest triton version with
+# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
+import pickle
 import math
 import torch
 import torch.nn as nn
@@ -6,65 +8,161 @@ import torch.nn.functional as F
 
 from einops import rearrange, repeat
 
-from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
+from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
+
+from flash_attn import flash_attn_qkvpacked_func
+
+try:
+    from triton.ops.flash_attention import attention as attention_triton
+except ImportError:
+    attention_triton = None
+
+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None
 
 
-def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     """
     Arguments:
         qkv: (batch_size, seqlen, 3, nheads, head_dim)
-        attn_mask: (batch_size, seqlen)
         dropout_p: float
     Output:
         output: (batch_size, seqlen, nheads, head_dim)
-        attention: softmax after dropout
     """
-    q, k, v = (qkv.float() if upcast else qkv).unbind(dim=2)
-    seqlen = qkv.shape[1]
-    d = qkv.shape[-1]
-    scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
-    scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
     if causal:
-        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
-        scores.masked_fill_(causal_mask, float('-inf'))
+        # "triu_tril_cuda_template" not implemented for 'BFloat16'
+        # So we have to construct the mask in float
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
     attention = torch.softmax(scores, dim=-1)
     attention_drop = F.dropout(attention, dropout_p)
     output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
-    # return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
     return output.to(dtype=qkv.dtype)
 
 
-torch.manual_seed(0)
+def time_fwd_bwd(func, *args, **kwargs):
+    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+    return time_f[1].mean, time_b[1].mean
+
+
 repeats = 30
-batch_size = 64
-nheads = 16
-seqlen = 1024
-n = 1024
-d = n // nheads
-dropout_p = 0.1
-causal = False
-dtype = torch.float16
 device = 'cuda'
+dtype = torch.float16
 
-x = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
-Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+causal_vals = [False, True]
+headdim_vals = [64, 128]
+dim = 2048
+dropout_p = 0.0
 
-lengths = torch.randint(seqlen - 20, seqlen, (batch_size, 1), device='cuda')
-attention_mask_bool = repeat(torch.arange(seqlen, device='cuda'), 's -> b s', b=batch_size) < lengths
-attention_mask = torch.zeros(batch_size, seqlen, device='cuda', dtype=dtype)
-attention_mask[~attention_mask_bool] = -10000.0
-attention_mask = rearrange(attention_mask, 'b s -> b 1 1 s')
+methods = (["Flash2", "Pytorch"]
+           + (["Triton"] if attention_triton is not None else [])
+           + (["xformers"] if xops is not None else []))
 
-x_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, attention_mask_bool)
-qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
-                      h=nheads).detach().requires_grad_()
-qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
+time_f = {}
+time_b = {}
+time_f_b = {}
+speed_f = {}
+speed_b = {}
+speed_f_b = {}
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            config = (causal, headdim, batch_size, seqlen)
+            nheads = dim // headdim
+            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                              requires_grad=True)
+            f, b = time_fwd_bwd(
+                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+            )
+            time_f[config, "Flash2"] = f
+            time_b[config, "Flash2"] = b
 
-fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func(
-    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
-)
-benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
-fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
-benchmark_all(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')
+            try:
+                qkv = qkv.detach().requires_grad_(True)
+                f, b = time_fwd_bwd(
+                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+                )
+            except:  # Skip if OOM
+                f, b = float('nan'), float('nan')
+            time_f[config, "Pytorch"] = f
+            time_b[config, "Pytorch"] = b
+
+            if attention_triton is not None:
+                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
+                                       requires_grad=True) for _ in range(3)]
+                # Try both values of sequence_parallel and pick the faster one
+                try:
+                    f, b = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5),
+                        False, repeats=repeats, verbose=False
+                    )
+                except:
+                    f, b = float('nan'), float('inf')
+                try:
+                    _, b0 = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5),
+                        True, repeats=repeats, verbose=False
+                    )
+                except:
+                    b0 = float('inf')
+                time_f[config, "Triton"] = f
+                time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
+
+            if xops is not None:
+                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+                                       requires_grad=True) for _ in range(3)]
+                f, b = time_fwd_bwd(
+                    xops.memory_efficient_attention, q, k, v,
+                    attn_bias=xops.LowerTriangularMask() if causal else None,
+                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
+                )
+                time_f[config, "xformers"] = f
"xformers"] = f + time_b[config, "xformers"] = b + + print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") + for method in methods: + time_f_b[config, method] = time_f[config, method] + time_b[config, method] + speed_f[config, method] = efficiency( + flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), + time_f[config, method] + ) + speed_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"), + time_b[config, method] + ) + speed_f_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), + time_f_b[config, method] + ) + print( + f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " + f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " + f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" + ) + + +# with open('flash2_attn_time.plk', 'wb') as fp: +# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/flash_attn/utils/benchmark.py b/flash_attn/utils/benchmark.py index c74b1a8..a5f42c2 100644 --- a/flash_attn/utils/benchmark.py +++ b/flash_attn/utils/benchmark.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, Tri Dao. +# Copyright (c) 2023, Tri Dao. """ Useful functions for writing test code. """ import torch @@ -10,14 +10,12 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False, """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """ if verbose: print(desc, '- Forward pass') - def fn_amp(*inputs, **kwinputs): + def amp_wrapper(*inputs, **kwinputs): with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): fn(*inputs, **kwinputs) - for _ in range(repeats): # warmup - fn_amp(*inputs, **kwinputs) t = benchmark.Timer( stmt='fn_amp(*inputs, **kwinputs)', - globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs}, + globals={'fn_amp': amp_wrapper, 'inputs': inputs, 'kwinputs': kwinputs}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) @@ -40,13 +38,18 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True else: if grad.shape != y.shape: raise RuntimeError('Grad shape does not match output shape') - for _ in range(repeats): # warmup + def f(*inputs, y, grad): + # Set .grad to None to avoid extra operation of gradient accumulation + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None y.backward(grad, retain_graph=True) + t = benchmark.Timer( - stmt='y.backward(grad, retain_graph=True)', - globals={'y': y, 'grad': grad}, + stmt='f(*inputs, y=y, grad=grad)', + globals={'f': f, 'inputs': inputs, 'y': y, 'grad': grad}, num_threads=torch.get_num_threads(), - ) + ) m = t.timeit(repeats) if verbose: print(m) @@ -58,19 +61,24 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. 
""" if verbose: print(desc, '- Forward + Backward pass') + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError('Grad shape does not match output shape') def f(grad, *inputs, **kwinputs): + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): y = fn(*inputs, **kwinputs) if type(y) is tuple: y = y[0] - if grad is None: - grad = torch.randn_like(y) - else: - if grad.shape != y.shape: - raise RuntimeError('Grad shape does not match output shape') y.backward(grad, retain_graph=True) - for _ in range(repeats): # warmup - f(grad, *inputs, **kwinputs) t = benchmark.Timer( stmt='f(grad, *inputs, **kwinputs)', globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs}, @@ -82,6 +90,17 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True return t, m +def benchmark_fwd_bwd(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, + amp_dtype=torch.float16, **kwinputs): + """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ + return ( + benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, + amp=amp, amp_dtype=amp_dtype, **kwinputs), + benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, + amp=amp, amp_dtype=amp_dtype, **kwinputs), + ) + + def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs): """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ @@ -102,16 +121,15 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): g = torch.randn_like(fn(*inputs, **kwinputs)) for _ in range(30): # Warm up + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): - if backward: - for x in inputs: - if isinstance(x, torch.Tensor): - x.grad = None - # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g) out = fn(*inputs, **kwinputs) # Backward should be done outside autocast if backward: - out.backward(g) + out.backward(g, retain_graph=True) activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA] with torch.profiler.profile( activities=activities, @@ -119,13 +137,13 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False # profile_memory=True, with_stack=True, ) as prof: + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): - if backward: - for x in inputs: - if isinstance(x, torch.Tensor): - x.grad = None out = fn(*inputs, **kwinputs) - if backward: out.backward(g) + if backward: out.backward(g, retain_graph=True) if verbose: # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) print(prof.key_averages().table(row_limit=50))