diff --git a/hopper/flash_bwd_kernel.h b/hopper/flash_bwd_kernel.h
index 63cfd78..ccd53f3 100644
--- a/hopper/flash_bwd_kernel.h
+++ b/hopper/flash_bwd_kernel.h
@@ -251,11 +251,11 @@ public:
             if constexpr (Varlen) {
                 if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
             }
-            // if constexpr (Is_causal || Is_local) {
-            //     int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
-            //     int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
-            //     if (m_block_min >= m_block_max) { continue; }
-            // }
+            if constexpr (Is_causal) {
+                int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
+                int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
+                if (m_block_min >= m_block_max) { continue; }
+            }
             collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
         }
     }
@@ -284,9 +284,6 @@ public:
             if constexpr (Is_causal || Is_local) {
                 int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
                 int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
-                auto seqlen_q = collective_mainloop.get_seqlen_q(params.mainloop, bidb);
-                auto seqlen_k = collective_mainloop.get_seqlen_k(params.mainloop, bidb);
-                auto original_m_block_max = cute::ceil_div(seqlen_q, kBlockM);
                 if (m_block_min >= m_block_max) {  // We exit early and write 0 to dK and dV
                     collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
                     continue;
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index 82f344d..dc8b35b 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -1,8 +1,5 @@
 import math
 
-import sys
-sys.path.remove("/home/yingz/llm_inference")
-
 import pytest
 import torch
 import torch.nn.functional as F
@@ -86,11 +83,10 @@ def test_flash_attn_output(
     batch_size = 4
     nheads = 6
     nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
-    # nheads_kv = 1
-    # batch_size = 1
-    # nheads = 1
+    # nheads_kv = 2
+    # batch_size = 9
+    # nheads = 6
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
-    print(f"window_size: {window_size}", flush=True)
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
@@ -255,12 +251,12 @@ def test_flash_attn_varlen_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
+    # batch_size = 1
+    # nheads = 1
+    # nheads_kv = 1
     batch_size = 9
-    nheads = 4
-    nheads_kv = 4
-    # batch_size = 9
-    # nheads = 6
-    # nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
+    nheads = 6
+    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
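
Context for the kernel change above: the dQ-store path now skips a key block when no query block can attend to it under the causal mask, using `cute::ceil_div(seqlen_q, kBlockM)` as the upper bound rather than `get_m_block_max`. Below is a minimal Python sketch of that block-range reasoning, assuming the usual bottom-right-aligned causal convention; `causal_m_block_range` and its arguments are illustrative names, not helpers from this repo.

```python
import math

def causal_m_block_range(n_block, seqlen_q, seqlen_k, kBlockM, kBlockN):
    """Illustrative sketch (assumed, not the kernel's exact helpers): for key block
    n_block under a bottom-right-aligned causal mask, return the [min, max) range
    of query blocks that can attend to it."""
    # Query i attends to key j iff j <= i + seqlen_k - seqlen_q, so the first query
    # index that can see key n_block * kBlockN is:
    first_q = n_block * kBlockN + seqlen_q - seqlen_k
    m_block_min = max(first_q, 0) // kBlockM
    # Every later query block may contribute to dQ, so the upper bound is simply
    # ceil(seqlen_q / kBlockM), mirroring ceil_div(seqlen_q, kBlockM) in the diff.
    m_block_max = math.ceil(seqlen_q / kBlockM)
    return m_block_min, m_block_max

# If m_block_min >= m_block_max, no query block reaches this key block under the
# causal mask, so the dQ store for it can be skipped.
```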