diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py
index 92b4b76..03ae29d 100644
--- a/benchmarks/benchmark_flash_attention.py
+++ b/benchmarks/benchmark_flash_attention.py
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 
 from einops import rearrange, repeat
 
-from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
 
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
diff --git a/benchmarks/utils.py b/flash_attn/utils/benchmark.py
similarity index 57%
rename from benchmarks/utils.py
rename to flash_attn/utils/benchmark.py
index 0cc472e..c74b1a8 100644
--- a/benchmarks/utils.py
+++ b/flash_attn/utils/benchmark.py
@@ -1,17 +1,23 @@
-# Adapted from https://github.com/HazyResearch/hippo/blob/datasets/benchmark/utils.py
+# Copyright (c) 2022, Tri Dao.
 """ Useful functions for writing test code. """
 
 import torch
 import torch.utils.benchmark as benchmark
 
 
-def benchmark_forward(fn, *inputs, min_run_time = 0.2, repeats = 10, desc='', verbose=True, **kwinputs):
+def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
+                      amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward pass')
+    def fn_amp(*inputs, **kwinputs):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            fn(*inputs, **kwinputs)
+    for _ in range(repeats): # warmup
+        fn_amp(*inputs, **kwinputs)
     t = benchmark.Timer(
-            stmt='fn(*inputs, **kwinputs)',
-            globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
+            stmt='fn_amp(*inputs, **kwinputs)',
+            globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
             num_threads=torch.get_num_threads(),
             )
     m = t.timeit(repeats)
@@ -20,50 +26,51 @@ def benchmark_forward(fn, *inputs, min_run_time = 0.2, repeats = 10, desc='', ve
     return t, m
 
 
-def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Backward pass')
-    y = fn(*inputs, **kwinputs)
-    if type(y) is tuple:
-        y = y[0]
+    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+        y = fn(*inputs, **kwinputs)
+        if type(y) is tuple:
+            y = y[0]
     if grad is None:
         grad = torch.randn_like(y)
     else:
         if grad.shape != y.shape:
             raise RuntimeError('Grad shape does not match output shape')
+    for _ in range(repeats): # warmup
+        y.backward(grad, retain_graph=True)
     t = benchmark.Timer(
             stmt='y.backward(grad, retain_graph=True)',
             globals={'y': y, 'grad': grad},
             num_threads=torch.get_num_threads(),
-            )
+        )
     m = t.timeit(repeats)
     if verbose:
         print(m)
     return t, m
 
 
-def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
""" if verbose: print(desc, '- Forward + Backward pass') - # y = fn(*inputs, **kwinputs) - # if grad is None: - # grad = torch.randn_like(y) - # else: - # if grad.shape != y.shape: - # raise RuntimeError('Grad shape does not match output shape') - # del y def f(grad, *inputs, **kwinputs): - y = fn(*inputs, **kwinputs) - if type(y) is tuple: - y = y[0] + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] if grad is None: grad = torch.randn_like(y) else: if grad.shape != y.shape: raise RuntimeError('Grad shape does not match output shape') y.backward(grad, retain_graph=True) + for _ in range(repeats): # warmup + f(grad, *inputs, **kwinputs) t = benchmark.Timer( stmt='f(grad, *inputs, **kwinputs)', globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs}, @@ -75,43 +82,53 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True return t, m -def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs): +def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, + amp_dtype=torch.float16, **kwinputs): """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ return ( - benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, **kwinputs), + benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, + amp=amp, amp_dtype=amp_dtype, **kwinputs), benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, - **kwinputs), + amp=amp, amp_dtype=amp_dtype, **kwinputs), benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, - **kwinputs), + amp=amp, amp_dtype=amp_dtype, **kwinputs), ) -def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, verbose=True): +def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, + amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs): """ Wrap benchmark functions in Pytorch profiler to see CUDA information. 
""" if backward: - g = torch.randn_like(fn(*inputs)) - for _ in range(10): # Warm up - with torch.autocast(device_type='cuda', enabled=amp): + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + g = torch.randn_like(fn(*inputs, **kwinputs)) + for _ in range(30): # Warm up + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): if backward: for x in inputs: if isinstance(x, torch.Tensor): x.grad = None - fn(*inputs) if not backward else fn(*inputs).backward(g) + # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g) + out = fn(*inputs, **kwinputs) + # Backward should be done outside autocast + if backward: + out.backward(g) + activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA] with torch.profiler.profile( - # activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], - activities=[torch.profiler.ProfilerActivity.CUDA,], + activities=activities, record_shapes=True, # profile_memory=True, with_stack=True, ) as prof: - with torch.autocast(device_type='cuda', enabled=amp): + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): if backward: for x in inputs: if isinstance(x, torch.Tensor): x.grad = None - fn(*inputs) if not backward else fn(*inputs).backward(g) + out = fn(*inputs, **kwinputs) + if backward: out.backward(g) if verbose: - print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + print(prof.key_averages().table(row_limit=50)) if trace_filename is not None: prof.export_chrome_trace(trace_filename) @@ -124,6 +141,6 @@ def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs): torch.cuda.synchronize() mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000) if verbose: - print(f'{desc} max memory: ', mem) + print(f'{desc} max memory: {mem}GB') torch.cuda.empty_cache() return mem