diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py
index 306b162..e20f2e2 100644
--- a/hopper/benchmark_attn.py
+++ b/hopper/benchmark_attn.py
@@ -48,7 +48,7 @@ def convert_to_cudnn_type(torch_type):
         raise ValueError("Unsupported tensor data type.")
 
 
-def cudnn_sdpa_setup(q, k, v, grad, causal=False):
+def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
     b, nheads, seqlen_q, headdim = q.shape
     _, _, seqlen_k, _ = k.shape
     assert v.shape == (b, nheads, seqlen_k, headdim)
@@ -65,6 +65,10 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
     k_forward = graph_forward.tensor_like(k_gpu.detach())
     v_forward = graph_forward.tensor_like(v_gpu.detach())
 
+    seqlens_reshaped = seqlens.reshape(b, 1, 1, 1).contiguous().cuda() if varlen else None
+    seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
+    seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
+
     o_forward, stats_forward = graph_forward.sdpa(
         name="sdpa",
         q=q_forward,
@@ -73,6 +77,9 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
         is_inference=False,
         attn_scale=1.0 / math.sqrt(headdim),
         use_causal_mask=causal,
+        use_padding_mask=varlen,
+        seq_len_q=seq_len_q,
+        seq_len_kv=seq_len_kv,
     )
 
     o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
@@ -90,6 +97,8 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
         v_forward: v_gpu,
         o_forward: o_gpu,
         stats_forward: stats_gpu,
+        seq_len_q: seqlens_reshaped,
+        seq_len_kv: seqlens_reshaped,
     }
 
     dQ_gpu = torch.empty_like(q_gpu)
@@ -109,6 +118,8 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
     o_backward = graph_backward.tensor_like(o_gpu.detach())
     dO_backward = graph_backward.tensor_like(dO_gpu.detach())
     stats_backward = graph_backward.tensor_like(stats_gpu.detach())
+    seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
+    seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
 
     dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
         name="sdpa_backward",
@@ -120,6 +131,9 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
         stats=stats_backward,
         attn_scale=1.0 / math.sqrt(headdim),
         use_causal_mask=causal,
+        use_padding_mask=varlen,
+        seq_len_q=seq_len_q,
+        seq_len_kv=seq_len_kv,
     )
 
     dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
@@ -142,6 +156,8 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False):
         dQ_backward: dQ_gpu,
         dK_backward: dK_gpu,
         dV_backward: dV_gpu,
+        seq_len_q: seqlens_reshaped,
+        seq_len_kv: seqlens_reshaped,
     }
 
     workspace = torch.empty(
@@ -208,8 +224,12 @@ for mode in ['fwd']:
             for causal in [False, True]:
             # for causal in [True]:
                 print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
+                # For var-seq-len
+                lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
+                cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
                 if headdim <= 128 and cudnn is not None:
                     cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
+                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal, varlen=True, seqlens=lens)
                 f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
                 _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                 if mode == 'bwd':
@@ -234,6 +254,7 @@ for mode in ['fwd']:
                     time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                     if mode == 'fwd':
                         _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
+                        _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                     else:
                         cudnn_sdpa_fwd()
                         _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
@@ -248,8 +269,6 @@ for mode in ['fwd']:
                 q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                 k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                 v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
-                lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
-                cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
                 time.sleep(1)
                 _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
                 if mode == 'bwd':
@@ -267,6 +286,7 @@ for mode in ['fwd']:
                     print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                 if cudnn is not None:
                     print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
+                    print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
                 if headdim == 128 or mode == 'fwd':
                     print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                     print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
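
The standalone sketch below (not part of the patch) illustrates the two variable-length representations the change relies on: flash_attn_varlen_func_v3 consumes cumulative sequence lengths (cu_seqlens), while the cuDNN graph's padding mask takes per-batch lengths reshaped to (b, 1, 1, 1), mirroring seqlens_reshaped in cudnn_sdpa_setup. The batch_size and seqlen values are illustrative, and tensors are kept on CPU for simplicity.

# Standalone sketch, illustrative sizes only.
import torch

batch_size, seqlen = 4, 8192

# Per-sequence lengths; the benchmark uses full-length sequences,
# so every entry equals seqlen.
lens = torch.full([batch_size], seqlen, dtype=torch.int32)

# flash_attn_varlen_func_v3 takes cumulative lengths: a (batch_size + 1)-long
# prefix sum starting at 0.
cu_seqlens = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(lens, dim=0, dtype=torch.int32),
])

# The cuDNN padding mask instead takes per-batch lengths broadcast to
# (b, 1, 1, 1), matching seqlens_reshaped in the patch (which also moves
# it to the GPU with .cuda()).
seqlens_reshaped = lens.reshape(batch_size, 1, 1, 1).contiguous()

print(cu_seqlens.tolist())            # [0, 8192, 16384, 24576, 32768]
print(tuple(seqlens_reshaped.shape))  # (4, 1, 1, 1)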