small fixes

Ying Zhang 2024-09-16 15:50:55 -07:00
parent dff976a84a
commit be6c1b98c4
2 changed files with 13 additions and 20 deletions

@@ -251,11 +251,11 @@ public:
if constexpr (Varlen) {
if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
}
// if constexpr (Is_causal || Is_local) {
// int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
// int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
// if (m_block_min >= m_block_max) { continue; }
// }
if constexpr (Is_causal) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
if (m_block_min >= m_block_max) { continue; }
}
collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
}
}
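
The hunk above skips a block's dQ store in two cases: when the K/V block starts at or past seqlen_k (the Varlen check), and, under a causal mask, when the query-block range [m_block_min, m_block_max) is empty, with m_block_max taken as ceil_div(seqlen_q, kBlockM). The small host-side sketch below just walks that arithmetic with made-up tile sizes; the causal m_block_min formula and the constants are illustrative assumptions, not the kernel's actual get_m_block_min / get_m_block_max.

    // Standalone sketch (illustrative, not the kernel's helpers) of the two skip
    // conditions above, evaluated with plain integers on the host.
    // Assumes a bottom-right-aligned causal mask for m_block_min; the real
    // get_m_block_min may differ (e.g. for local / sliding-window attention).
    #include <algorithm>
    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        // Hypothetical tile sizes and (varlen) sequence lengths for one batch entry.
        int const kBlockM = 64, kBlockN = 128;
        int const seqlen_q = 200, seqlen_k = 600;
        int const padded_seqlen_k = 768;                       // K/V padded past seqlen_k
        int const m_block_max = ceil_div(seqlen_q, kBlockM);   // as in the hunk above

        for (int n_block = 0; n_block < ceil_div(padded_seqlen_k, kBlockN); ++n_block) {
            if (n_block * kBlockN >= seqlen_k) {               // Varlen skip
                std::printf("n_block %d: past seqlen_k -> skip\n", n_block);
                continue;
            }
            // First query block that can see this key block under causal masking
            // (bottom-right aligned, so the offset is seqlen_q - seqlen_k).
            int const m_block_min =
                std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM);
            std::printf("n_block %d: m_block in [%d, %d)%s\n", n_block, m_block_min,
                        m_block_max, m_block_min >= m_block_max ? "  -> skip" : "");
        }
        return 0;
    }
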
@@ -284,9 +284,6 @@ public:
if constexpr (Is_causal || Is_local) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
auto seqlen_q = collective_mainloop.get_seqlen_q(params.mainloop, bidb);
auto seqlen_k = collective_mainloop.get_seqlen_k(params.mainloop, bidb);
auto original_m_block_max = cute::ceil_div(seqlen_q, kBlockM);
if (m_block_min >= m_block_max) { // We exit early and write 0 to dK and dV
collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
continue;

@@ -1,8 +1,5 @@
import math
import sys
sys.path.remove("/home/yingz/llm_inference")
import pytest
import torch
import torch.nn.functional as F
@@ -86,11 +83,10 @@ def test_flash_attn_output(
batch_size = 4
nheads = 6
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
# nheads_kv = 1
# batch_size = 1
# nheads = 1
# nheads_kv = 2
# batch_size = 9
# nheads = 6
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
print(f"window_size: {window_size}", flush=True)
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
@@ -255,12 +251,12 @@ def test_flash_attn_varlen_output(
device = "cuda"
# set seed
torch.random.manual_seed(0)
# batch_size = 1
# nheads = 1
# nheads_kv = 1
batch_size = 9
nheads = 4
nheads_kv = 4
# batch_size = 9
# nheads = 6
# nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
nheads = 6
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))