add cudnn benchmark for var-len
This commit is contained in:
parent
5018ac6ac5
commit
c7f20a2d31
@ -48,7 +48,7 @@ def convert_to_cudnn_type(torch_type):
|
|||||||
raise ValueError("Unsupported tensor data type.")
|
raise ValueError("Unsupported tensor data type.")
|
||||||
|
|
||||||
|
|
||||||
def cudnn_sdpa_setup(q, k, v, grad, causal=False):
|
def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
|
||||||
b, nheads, seqlen_q, headdim = q.shape
|
b, nheads, seqlen_q, headdim = q.shape
|
||||||
_, _, seqlen_k, _ = k.shape
|
_, _, seqlen_k, _ = k.shape
|
||||||
assert v.shape == (b, nheads, seqlen_k, headdim)
|
assert v.shape == (b, nheads, seqlen_k, headdim)
|
||||||
@ -65,6 +65,10 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
|
|||||||
k_forward = graph_forward.tensor_like(k_gpu.detach())
|
k_forward = graph_forward.tensor_like(k_gpu.detach())
|
||||||
v_forward = graph_forward.tensor_like(v_gpu.detach())
|
v_forward = graph_forward.tensor_like(v_gpu.detach())
|
||||||
|
|
||||||
|
seqlens_reshaped = seqlens.reshape(b, 1, 1, 1).contiguous().cuda() if varlen else None
|
||||||
|
seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
||||||
|
seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
||||||
|
|
||||||
o_forward, stats_forward = graph_forward.sdpa(
|
o_forward, stats_forward = graph_forward.sdpa(
|
||||||
name="sdpa",
|
name="sdpa",
|
||||||
q=q_forward,
|
q=q_forward,
|
||||||
@ -73,6 +77,9 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
|
|||||||
is_inference=False,
|
is_inference=False,
|
||||||
attn_scale=1.0 / math.sqrt(headdim),
|
attn_scale=1.0 / math.sqrt(headdim),
|
||||||
use_causal_mask=causal,
|
use_causal_mask=causal,
|
||||||
|
use_padding_mask=varlen,
|
||||||
|
seq_len_q=seq_len_q,
|
||||||
|
seq_len_kv=seq_len_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
|
o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
|
||||||
@ -90,6 +97,8 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
|
|||||||
v_forward: v_gpu,
|
v_forward: v_gpu,
|
||||||
o_forward: o_gpu,
|
o_forward: o_gpu,
|
||||||
stats_forward: stats_gpu,
|
stats_forward: stats_gpu,
|
||||||
|
seq_len_q: seqlens_reshaped,
|
||||||
|
seq_len_kv: seqlens_reshaped,
|
||||||
}
|
}
|
||||||
|
|
||||||
dQ_gpu = torch.empty_like(q_gpu)
|
dQ_gpu = torch.empty_like(q_gpu)
|
||||||
@ -109,6 +118,8 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
|
|||||||
o_backward = graph_backward.tensor_like(o_gpu.detach())
|
o_backward = graph_backward.tensor_like(o_gpu.detach())
|
||||||
dO_backward = graph_backward.tensor_like(dO_gpu.detach())
|
dO_backward = graph_backward.tensor_like(dO_gpu.detach())
|
||||||
stats_backward = graph_backward.tensor_like(stats_gpu.detach())
|
stats_backward = graph_backward.tensor_like(stats_gpu.detach())
|
||||||
|
seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
||||||
|
seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
||||||
|
|
||||||
dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
|
dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
|
||||||
name="sdpa_backward",
|
name="sdpa_backward",
|
||||||
@ -120,6 +131,9 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
|
|||||||
stats=stats_backward,
|
stats=stats_backward,
|
||||||
attn_scale=1.0 / math.sqrt(headdim),
|
attn_scale=1.0 / math.sqrt(headdim),
|
||||||
use_causal_mask=causal,
|
use_causal_mask=causal,
|
||||||
|
use_padding_mask=varlen,
|
||||||
|
seq_len_q=seq_len_q,
|
||||||
|
seq_len_kv=seq_len_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
|
dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
|
||||||
@ -142,6 +156,8 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
|
|||||||
dQ_backward: dQ_gpu,
|
dQ_backward: dQ_gpu,
|
||||||
dK_backward: dK_gpu,
|
dK_backward: dK_gpu,
|
||||||
dV_backward: dV_gpu,
|
dV_backward: dV_gpu,
|
||||||
|
seq_len_q: seqlens_reshaped,
|
||||||
|
seq_len_kv: seqlens_reshaped,
|
||||||
}
|
}
|
||||||
|
|
||||||
workspace = torch.empty(
|
workspace = torch.empty(
|
||||||
@ -208,8 +224,12 @@ for mode in ['fwd']:
|
|||||||
for causal in [False, True]:
|
for causal in [False, True]:
|
||||||
# for causal in [True]:
|
# for causal in [True]:
|
||||||
print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
|
print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
|
||||||
|
# For var-seq-len
|
||||||
|
lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
|
||||||
|
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
|
||||||
if headdim <= 128 and cudnn is not None:
|
if headdim <= 128 and cudnn is not None:
|
||||||
cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
|
cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
|
||||||
|
cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal, varlen=True, seqlens=lens)
|
||||||
f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
|
f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
|
||||||
_, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
|
_, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
|
||||||
if mode == 'bwd':
|
if mode == 'bwd':
|
||||||
@ -234,6 +254,7 @@ for mode in ['fwd']:
|
|||||||
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
||||||
if mode == 'fwd':
|
if mode == 'fwd':
|
||||||
_, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
_, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
||||||
|
_, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
|
||||||
else:
|
else:
|
||||||
cudnn_sdpa_fwd()
|
cudnn_sdpa_fwd()
|
||||||
_, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
_, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
||||||
@ -248,8 +269,6 @@ for mode in ['fwd']:
|
|||||||
q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
|
q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
|
||||||
k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
|
k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
|
||||||
v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
|
v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
|
||||||
lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
|
|
||||||
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
_, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
|
_, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
|
||||||
if mode == 'bwd':
|
if mode == 'bwd':
|
||||||
@ -267,6 +286,7 @@ for mode in ['fwd']:
|
|||||||
print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
|
print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
|
||||||
if cudnn is not None:
|
if cudnn is not None:
|
||||||
print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
|
print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
|
||||||
|
print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
|
||||||
if headdim == 128 or mode == 'fwd':
|
if headdim == 128 or mode == 'fwd':
|
||||||
print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
|
print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
|
||||||
print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
|
print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user