From 3669b25206d5938e3cc74a5f7860e31c38af8204 Mon Sep 17 00:00:00 2001
From: Ying Zhang
Date: Mon, 5 Aug 2024 21:27:52 -0700
Subject: [PATCH] bwd benchmark + small fixes (#1129)

---
 hopper/benchmark_attn.py         | 57 ++++++++++++++++++++++----------
 hopper/epilogue_fwd_sm90_tma.hpp |  2 +-
 2 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py
index e20f2e2..74d2ce3 100644
--- a/hopper/benchmark_attn.py
+++ b/hopper/benchmark_attn.py
@@ -48,14 +48,13 @@ def convert_to_cudnn_type(torch_type):
         raise ValueError("Unsupported tensor data type.")


-def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
+def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
     b, nheads, seqlen_q, headdim = q.shape
-    _, _, seqlen_k, _ = k.shape
-    assert v.shape == (b, nheads, seqlen_k, headdim)
+    _, nheads_kv, seqlen_k, _ = k.shape
+    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
     assert cudnn is not None, 'CUDNN is not available'
     q_gpu, k_gpu, v_gpu = q, k, v
-    o_gpu = torch.empty_like(q_gpu)
-    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
+    o_gpu, stats_gpu = o, stats
     graph_forward = cudnn.pygraph(
         io_data_type=convert_to_cudnn_type(q.dtype),
         intermediate_data_type=cudnn.data_type.FLOAT,
@@ -65,7 +64,7 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
     k_forward = graph_forward.tensor_like(k_gpu.detach())
     v_forward = graph_forward.tensor_like(v_gpu.detach())

-    seqlens_reshaped = seqlens.reshape(b, 1, 1, 1).contiguous().cuda() if varlen else None
+    seqlens_reshaped = seqlens if varlen else None
     seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
     seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None

@@ -193,8 +192,8 @@ dim = 2048
 # headdim = 64
 headdim = 256

-# for mode in ['fwd', 'bwd']:
-for mode in ['fwd']:
+for mode in ['fwd', 'bwd']:
+# for mode in ['bwd']:
     for headdim in [64, 128, 256]:
     # for headdim in [128]:
         for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
@@ -206,31 +205,38 @@ for mode in ['fwd']:
             # seqlen = 512
             # nheads = 8
             # headdim = 128
+            # nheads = 16
+            # headdim = 128
             nheads_kv = nheads
+            # nheads_kv = 1

             qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                               requires_grad=True)
             q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
+            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
             q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
             k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
             grad_t = grad.transpose(1, 2).contiguous()
+            o_t = torch.empty_like(q.transpose(1, 2))
+            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)

             bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)

             for causal in [False, True]:
             # for causal in [True]:
-                print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
+                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
                 # For var-seq-len
                 lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
+                seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
                 cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
                 if headdim <= 128 and cudnn is not None:
-                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
-                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal, varlen=True, seqlens=lens)
+                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
+                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
                 f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
+                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
                 _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                 if mode == 'bwd':
                     ref_dv, v.grad = v.grad.clone(), None
@@ -238,7 +244,7 @@ for mode in ['fwd']:
                     ref_dq, q.grad = q.grad.clone(), None
                 # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if triton_attention is not None and nheads_kv == nheads:
                         if mode == 'fwd':
                             time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
                             _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
@@ -255,22 +261,31 @@ for mode in ['fwd']:
                         if mode == 'fwd':
                             _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                             _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
+                            cudnn_sdpa_fwd()
+                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
+                            cudnn_sdpa_fwd_varlen()
+                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                         else:
                             cudnn_sdpa_fwd()
                             _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
+                            _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                             dq, dk, dv = cudnn_sdpa_bwd()
                             torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                             torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                             torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
+                            dq, dk, dv = cudnn_sdpa_bwd_varlen()
+                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
+                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
+                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                         # pytorch_profiler(cudnn_sdpa, backward=False)
+
-                if headdim == 128 or mode == 'fwd':
+                if headdim <= 128 or mode == 'fwd':
                     time.sleep(1)
                     _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
                     q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                     k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                     v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
                     time.sleep(1)
-                    _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
                     if mode == 'bwd':
                         dv, v.grad = v.grad.clone(), None
                         dk, k.grad = k.grad.clone(), None
@@ -279,15 +294,21 @@ for mode in ['fwd']:
                         torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
                         torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
+                    bench_var_fn = bench_fn
+                    if mode == 'bwd':
+                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
+                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
+                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
+


                 # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
                 print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
                         print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                     if cudnn is not None:
                         print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
                         print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
-                if headdim == 128 or mode == 'fwd':
+                if headdim <= 128 or mode == 'fwd':
                     print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                     print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
\ No newline at end of file
diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp
index ec9d45f..5133c55 100644
--- a/hopper/epilogue_fwd_sm90_tma.hpp
+++ b/hopper/epilogue_fwd_sm90_tma.hpp
@@ -288,7 +288,7 @@ struct CollectiveEpilogueFwd {
             gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
+        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
     }

 };
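
Note on the nheads_kv changes above (illustrative only, not part of the patch): k and v are now drawn with nheads_kv heads, which may be smaller than the number of query heads, and the Triton path is gated on nheads_kv == nheads, presumably because that kernel expects matching head counts. The sketch below shows the standard grouped-query attention semantics such a setting implies, in plain PyTorch; repeat_interleave and scaled_dot_product_attention stand in for the fused kernels being benchmarked.

import torch
import torch.nn.functional as F

# Small toy sizes; the benchmark uses much larger ones.
b, s, nheads, nheads_kv, d = 2, 8, 8, 2, 16
q = torch.randn(b, s, nheads, d)
k = torch.randn(b, s, nheads_kv, d)
v = torch.randn(b, s, nheads_kv, d)

# Each K/V head serves a group of nheads // nheads_kv query heads.
group = nheads // nheads_kv
k_rep = k.repeat_interleave(group, dim=2)   # (b, s, nheads, d)
v_rep = v.repeat_interleave(group, dim=2)

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k_rep.transpose(1, 2), v_rep.transpose(1, 2), is_causal=False
).transpose(1, 2)
print(out.shape)  # torch.Size([2, 8, 8, 16])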
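
The new bench_var_fn wiring exists because timing the varlen backward needs the output gradient in the same packed (total_tokens, nheads, headdim) layout as q_var/k_var/v_var. Below is a minimal sketch of that layout in plain PyTorch on CPU; toy_attn is a hypothetical stand-in for flash_attn_varlen_func_v3, and cu_seqlens is built exactly as in the benchmark's equal-length case.

import torch
import torch.nn.functional as F

batch_size, seqlen, nheads, headdim = 2, 8, 4, 16
q = torch.randn(batch_size, seqlen, nheads, headdim, requires_grad=True)
k = torch.randn(batch_size, seqlen, nheads, headdim, requires_grad=True)
v = torch.randn(batch_size, seqlen, nheads, headdim, requires_grad=True)
grad = torch.randn(batch_size, seqlen, nheads, headdim)

# Every sequence has length seqlen, so cu_seqlens is [0, seqlen, 2*seqlen, ...].
lens = torch.full([batch_size], seqlen, dtype=torch.int32)
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32),
                        torch.cumsum(lens, dim=0, dtype=torch.int32)])

# Packed (varlen) layout: (total_tokens, nheads, headdim).
q_var = q.reshape(-1, nheads, headdim)
k_var = k.reshape(-1, nheads, headdim)
v_var = v.reshape(-1, nheads, headdim)
grad_var = grad.reshape(-1, nheads, headdim)  # what bench_var_fn passes as grad=

def toy_attn(q_p, k_p, v_p):
    # Hypothetical stand-in: unpack to (b, h, s, d), run SDPA, repack.
    qh, kh, vh = (t.unflatten(0, (batch_size, seqlen)).transpose(1, 2) for t in (q_p, k_p, v_p))
    return F.scaled_dot_product_attention(qh, kh, vh).transpose(1, 2).reshape(-1, nheads, headdim)

out = toy_attn(q_var, k_var, v_var)
out.backward(grad_var)          # backward is timed with the reshaped gradient
print(cu_seqlens.tolist())      # [0, 8, 16]
print(q.grad.shape)             # torch.Size([2, 8, 4, 16])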
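
On the epilogue_fwd_sm90_tma.hpp change: the value written into gLSE flips from INFINITY to -INFINITY. -inf is the log-sum-exp of a score row with no unmasked keys, so it is the self-consistent value to store for such rows: exponentiating it yields 0 rather than inf. A tiny PyTorch illustration of that convention (illustrative only, not part of the patch):

import torch

# A fully masked score row: every key is excluded.
masked_scores = torch.full((4,), float('-inf'))

lse = torch.logsumexp(masked_scores, dim=0)
print(lse)             # tensor(-inf)  -> the natural value for an empty row's LSE
print(torch.exp(lse))  # tensor(0.)    -> contributes nothing downstream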