# Copyright (c) 2023, Tri Dao.

""" Useful functions for writing test code. """

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
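

# Illustrative usage (a sketch, not part of the original file; assumes a CUDA
# device is available and uses torch.matmul as a stand-in workload):
#
#     x = torch.randn(1024, 1024, device="cuda")
#     t, m = benchmark_forward(torch.matmul, x, x, desc="matmul")
#     print(m.mean)  # Measurement.mean: average seconds per call

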
def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid the extra gradient-accumulation op
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
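

# Illustrative usage (a sketch, not part of the original file). The inputs must
# require gradients so that the captured output has a backward graph:
#
#     x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
#     t, m = benchmark_backward(torch.matmul, x, x, desc="matmul")

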
def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        # Set .grad to None to avoid the extra gradient-accumulation op
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
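

# Illustrative usage (a sketch, not part of the original file). Unlike timing
# the passes separately, this times one region that reruns the forward and then
# the backward, so it reflects the end-to-end cost of a training step:
#
#     x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
#     t, m = benchmark_combined(torch.matmul, x, x, desc="matmul")

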
def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward and backward passes (timed separately) of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )
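

# Illustrative usage (a sketch, not part of the original file). The return value
# is a pair of (Timer, Measurement) tuples, one per pass:
#
#     (t_f, m_f), (t_b, m_b) = benchmark_fwd_bwd(
#         torch.matmul, x, x, desc="matmul"  # x as in the sketches above
#     )

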
def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward, backward, and combined forward+backward passes of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )
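

# Illustrative usage (a sketch, not part of the original file). Three
# (Timer, Measurement) tuples come back: forward, backward, and combined:
#
#     (t_f, m_f), (t_b, m_b), (t_c, m_c) = benchmark_all(
#         torch.matmul, x, x, desc="matmul", verbose=False
#     )
#     print(m_f.mean, m_b.mean, m_c.mean)

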
def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in PyTorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
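

# Illustrative usage (a sketch, not part of the original file). The exported
# Chrome trace can be inspected at chrome://tracing or https://ui.perfetto.dev:
#
#     x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
#     pytorch_profiler(torch.matmul, x, x, backward=True,
#                      trace_filename="matmul_trace.json")

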
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Measure the peak GPU memory allocated while running fn once."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # (2**20) * 1000 bytes is 1000 MiB (~1.05 GB), so the "GB" figure is approximate
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
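

# Illustrative usage (a sketch, not part of the original file):
#
#     x = torch.randn(8192, 8192, device="cuda")
#     benchmark_memory(torch.matmul, x, x, desc="matmul")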