bwd benchmark + small fixes (#1129)
parent 5d5bfbb619, commit 3669b25206
@@ -48,14 +48,13 @@ def convert_to_cudnn_type(torch_type):
         raise ValueError("Unsupported tensor data type.")


-def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
+def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
     b, nheads, seqlen_q, headdim = q.shape
-    _, _, seqlen_k, _ = k.shape
-    assert v.shape == (b, nheads, seqlen_k, headdim)
+    _, nheads_kv, seqlen_k, _ = k.shape
+    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
     assert cudnn is not None, 'CUDNN is not available'
     q_gpu, k_gpu, v_gpu = q, k, v
-    o_gpu = torch.empty_like(q_gpu)
-    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
+    o_gpu, stats_gpu = o, stats
     graph_forward = cudnn.pygraph(
         io_data_type=convert_to_cudnn_type(q.dtype),
         intermediate_data_type=cudnn.data_type.FLOAT,
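Note on the signature change: `cudnn_sdpa_setup` now derives `nheads_kv` from `k` instead of assuming it matches `nheads`, so the benchmark can cover grouped-query attention (GQA), where several query heads share one KV head, and MQA (`nheads_kv = 1`). A minimal sketch of the shape convention this implies, assuming `nheads % nheads_kv == 0`; the `repeat_interleave` expansion is only an eager-mode reference, not how the kernels actually handle GQA:

import torch

b, nheads, nheads_kv, seqlen, headdim = 2, 8, 2, 16, 64
q = torch.randn(b, nheads, seqlen, headdim)
k = torch.randn(b, nheads_kv, seqlen, headdim)
v = torch.randn(b, nheads_kv, seqlen, headdim)

# Each group of nheads // nheads_kv query heads attends to one shared KV head.
# An eager reference can simply repeat the KV heads up to nheads:
k_ref = k.repeat_interleave(nheads // nheads_kv, dim=1)
v_ref = v.repeat_interleave(nheads // nheads_kv, dim=1)
attn = (q @ k_ref.transpose(-2, -1) / headdim ** 0.5).softmax(dim=-1)
out = attn @ v_ref  # (b, nheads, seqlen, headdim), same shape as q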
@@ -65,7 +64,7 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
     k_forward = graph_forward.tensor_like(k_gpu.detach())
     v_forward = graph_forward.tensor_like(v_gpu.detach())

-    seqlens_reshaped = seqlens.reshape(b, 1, 1, 1).contiguous().cuda() if varlen else None
+    seqlens_reshaped = seqlens if varlen else None
     seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
     seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None

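The seqlens reshape moves from inside the setup function to the call site, so the caller prepares the `(b, 1, 1, 1)` int32 GPU tensor once per config instead of on every setup call. A sketch of what the caller now provides (mirroring `seqlens_cudnn` later in this diff):

import torch

batch_size, seqlen = 4, 1024
lens = torch.full([batch_size], seqlen, dtype=torch.int32)
# Per-batch sequence lengths, shaped the way the cuDNN graph expects them:
seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()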
@@ -193,8 +192,8 @@ dim = 2048
 # headdim = 64
 headdim = 256

-# for mode in ['fwd', 'bwd']:
-for mode in ['fwd']:
+for mode in ['fwd', 'bwd']:
+# for mode in ['bwd']:
     for headdim in [64, 128, 256]:
     # for headdim in [128]:
         for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
@@ -206,31 +205,38 @@ for mode in ['fwd']:
             # seqlen = 512
             # nheads = 8
             # headdim = 128
+            # nheads = 16
+            # headdim = 128
             nheads_kv = nheads
+            # nheads_kv = 1

             qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                               requires_grad=True)
             q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
+            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
             q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
             k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
             grad_t = grad.transpose(1, 2).contiguous()
+            o_t = torch.empty_like(q.transpose(1, 2))
+            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)

             bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)

             for causal in [False, True]:
             # for causal in [True]:
-                print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
+                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
                 # For var-seq-len
                 lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
+                seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
                 cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
                 if headdim <= 128 and cudnn is not None:
-                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
-                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal, varlen=True, seqlens=lens)
+                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
+                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
                 f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
+                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
                 _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                 if mode == 'bwd':
                     ref_dv, v.grad = v.grad.clone(), None
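For reference, a sketch of the FLOP accounting this style of benchmark typically uses; the repo's actual `flops` helper may differ in detail. The forward pass is two matmuls (Q @ K^T and P @ V) at 2 FLOPs per multiply-add, roughly halved by a causal mask, and the backward is conventionally counted as 2.5x the forward:

def flops_sketch(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode="fwd"):
    # Hypothetical re-derivation of the benchmark's FLOP count.
    assert mode in ("fwd", "bwd", "fwd_bwd")
    f = 4 * batch * nheads * seqlen_q * seqlen_k * headdim  # two matmuls
    if causal:
        f //= 2  # about half the score matrix is masked away
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)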
@@ -238,7 +244,7 @@ for mode in ['fwd']:
                     ref_dq, q.grad = q.grad.clone(), None
                 # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if triton_attention is not None and nheads_kv == nheads:
                         if mode == 'fwd':
                             time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                             _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
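The added `nheads_kv == nheads` guard simply skips the Triton kernel for GQA/MQA shapes: it is benchmarked here only with equal query and KV head counts.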
@@ -255,22 +261,31 @@ for mode in ['fwd']:
                     if mode == 'fwd':
                         _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                         _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
+                        cudnn_sdpa_fwd()
+                        torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
+                        cudnn_sdpa_fwd_varlen()
+                        torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                     else:
                         cudnn_sdpa_fwd()
                         _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
+                        _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                         dq, dk, dv = cudnn_sdpa_bwd()
                         torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                         torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                         torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
+                        dq, dk, dv = cudnn_sdpa_bwd_varlen()
+                        torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
+                        torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
+                        torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                     # pytorch_profiler(cudnn_sdpa, backward=False)
-                if headdim == 128 or mode == 'fwd':
+
+                if headdim <= 128 or mode == 'fwd':
                     time.sleep(1)
                     _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
                     q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                     k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                     v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
                     time.sleep(1)
-                    _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
                 if mode == 'bwd':
                     dv, v.grad = v.grad.clone(), None
                     dk, k.grad = k.grad.clone(), None
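Two details in this hunk are easy to miss. In the backward branch, `cudnn_sdpa_fwd()` runs once before the backward graphs are timed, because the backward graph reads the preallocated `o_t` and `stats` buffers that only the forward populates. And each captured graph is now validated against the eager references (`ref_o`, `ref_dq/dk/dv`) with loose tolerances appropriate for fp16/bf16 kernels. A condensed sketch of the buffer dependency, using names from the diff:

# Sketch: the bwd graph consumes buffers the fwd graph writes, so order matters.
cudnn_sdpa_fwd()               # populates o_t (output) and stats (per-row LSE)
dq, dk, dv = cudnn_sdpa_bwd()  # reads o_t / stats; returns gradients to verify
torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)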
@@ -279,15 +294,21 @@ for mode in ['fwd']:
                     torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
                     torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)

+                    bench_var_fn = bench_fn
+                    if mode == 'bwd':
+                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
+                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
+                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
+
                 # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
                 print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
                         print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                 if cudnn is not None:
                     print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
                     print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
-                if headdim == 128 or mode == 'fwd':
+                if headdim <= 128 or mode == 'fwd':
                     print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                     print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')

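For context, the packed ("varlen") layout that `flash_attn_varlen_func_v3` consumes: batch and sequence are flattened into a single token axis, with `cu_seqlens` marking sequence boundaries, which is why the backward benchmark must reshape `grad` the same way. A minimal sketch with the all-equal lengths used in this benchmark:

import torch

batch_size, seqlen, nheads, headdim = 4, 8, 2, 16
q = torch.randn(batch_size, seqlen, nheads, headdim)
q_var = q.reshape(-1, q.shape[-2], q.shape[-1])  # (total_tokens, nheads, headdim)

lens = torch.full([batch_size], seqlen, dtype=torch.int32)
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed tensor.
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32),
                        torch.cumsum(lens, dim=0, dtype=torch.int32)])
assert cu_seqlens.tolist() == [0, 8, 16, 24, 32]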
@@ -288,7 +288,7 @@ struct CollectiveEpilogueFwd {
                 gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
             );
             static_assert(kBlockM <= NumMmaThreads);
-            if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
+            if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
         }

     };
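On the `INFINITY` to `-INFINITY` fix: the rows written here are padding rows of the log-sum-exp (LSE) output. The usual convention is that an empty row carries lse = -inf, since exp(-inf) = 0 marks zero probability mass, whereas +inf yields inf - inf = NaN as soon as LSE values are combined (for example when merging per-split results). A small numeric illustration of that convention in Python (the kernel itself is CUDA):

import math

def combine_lse(lses):
    # Merge per-split log-sum-exp values: logsumexp of the parts.
    m = max(lses)
    if math.isinf(m) and m < 0:
        return -math.inf  # every split was empty
    return m + math.log(sum(math.exp(x - m) for x in lses))

print(combine_lse([2.0, -math.inf]))  # 2.0: -inf padding is inert
print(combine_lse([2.0, math.inf]))   # nan: +inf poisons the merge via inf - inf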
|
|||||||
Loading…
Reference in New Issue
Block a user