small fixes
parent dff976a84a
commit be6c1b98c4
@@ -251,11 +251,11 @@ public:
            if constexpr (Varlen) {
                if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
            }
            // if constexpr (Is_causal || Is_local) {
            //     int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
            //     int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
            //     if (m_block_min >= m_block_max) { continue; }
            // }
            if constexpr (Is_causal) {
                int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
                int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
                if (m_block_min >= m_block_max) { continue; }
            }
            collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
        }
    }
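For context (not part of the diff): the Is_causal branch above skips a key block outright when m_block_min >= m_block_max, i.e. when no query block contributes any dQ work for that key block. A minimal sketch of that test, assuming for illustration a top-left-aligned causal mask (query i attends key j iff j <= i) and made-up block sizes; the kernel's get_m_block_min may use a different alignment when seqlen_q != seqlen_k:

import math

# Illustration only, not the kernel code. Assumes a top-left-aligned causal
# mask; block sizes are hypothetical defaults for the example.
def skip_key_block(n_block, seqlen_q, kBlockM=64, kBlockN=128):
    m_block_min = (n_block * kBlockN) // kBlockM   # first query block that could attend this key block
    m_block_max = math.ceil(seqlen_q / kBlockM)    # cute::ceil_div(seqlen_q, kBlockM)
    return m_block_min >= m_block_max              # nothing for this key block to do

print(skip_key_block(n_block=0, seqlen_q=256))     # False: early keys are visible to some queries
print(skip_key_block(n_block=4, seqlen_q=256))     # True: keys from row 512 on see no queries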
@@ -284,9 +284,6 @@ public:
            if constexpr (Is_causal || Is_local) {
                int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
                int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
                auto seqlen_q = collective_mainloop.get_seqlen_q(params.mainloop, bidb);
                auto seqlen_k = collective_mainloop.get_seqlen_k(params.mainloop, bidb);
                auto original_m_block_max = cute::ceil_div(seqlen_q, kBlockM);
                if (m_block_min >= m_block_max) {  // We exit early and write 0 to dK and dV
                    collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
                    continue;
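A reference-level sanity check of the invariant behind that early exit (not part of the diff): key/value rows that no query attends receive exactly zero gradient, which is what writing zeros to dK and dV reproduces. Plain PyTorch with made-up shapes:

import torch

q = torch.randn(4, 8, requires_grad=True)   # 4 queries, head_dim 8
k = torch.randn(6, 8, requires_grad=True)   # 6 keys
v = torch.randn(6, 8, requires_grad=True)

mask = torch.zeros(4, 6, dtype=torch.bool)
mask[:, :3] = True                          # queries only see the first 3 keys

scores = (q @ k.T).masked_fill(~mask, float("-inf"))
out = torch.softmax(scores, dim=-1) @ v
out.sum().backward()

print(k.grad[3:])                           # all zeros: keys 3..5 are never attended
print(v.grad[3:])                           # all zeros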
@@ -1,8 +1,5 @@
import math

import sys
sys.path.remove("/home/yingz/llm_inference")

import pytest
import torch
import torch.nn.functional as F
@@ -86,11 +83,10 @@ def test_flash_attn_output(
    batch_size = 4
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # nheads_kv = 1
    # batch_size = 1
    # nheads = 1
    # nheads_kv = 2
    # batch_size = 9
    # nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    print(f"window_size: {window_size}", flush=True)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
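For the local branch, the test draws window_size = (left, right) at random. A sketch of the sliding-window mask such a pair describes (not part of the diff), assuming seqlen_q == seqlen_k and the flash-attn convention that -1 means no bound on that side, so query i may attend key j when i - left <= j <= i + right:

import torch

# Illustration only; the kernels never materialize this mask.
def sliding_window_mask(seqlen, window_size=(-1, -1)):
    left, right = window_size
    i = torch.arange(seqlen).unsqueeze(-1)   # query index, shape (seqlen, 1)
    j = torch.arange(seqlen)                 # key index, shape (seqlen,)
    mask = torch.ones(seqlen, seqlen, dtype=torch.bool)
    if left >= 0:
        mask &= j >= i - left
    if right >= 0:
        mask &= j <= i + right
    return mask

print(sliding_window_mask(5, window_size=(1, 1)).int())  # tridiagonal band of ones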
@@ -255,12 +251,12 @@ def test_flash_attn_varlen_output(
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 1
    # nheads = 1
    # nheads_kv = 1
    batch_size = 9
    nheads = 4
    nheads_kv = 4
    # batch_size = 9
    # nheads = 6
    # nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)

    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
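The mha/gqa/mqa head counts above fix how many query heads share each key/value head (one each, nheads // nheads_kv per group, or all of them onto a single head). A reference check typically expands K and V before running plain attention; a sketch with made-up shapes (not part of the diff):

import torch

batch, seqlen, d = 2, 16, 64
nheads, nheads_kv = 6, 2                      # gqa: 3 query heads per kv head

q = torch.randn(batch, seqlen, nheads, d)
k = torch.randn(batch, seqlen, nheads_kv, d)
v = torch.randn(batch, seqlen, nheads_kv, d)

# Expand kv heads so every query head has a matching key/value head.
k_exp = k.repeat_interleave(nheads // nheads_kv, dim=2)
v_exp = v.repeat_interleave(nheads // nheads_kv, dim=2)
print(k_exp.shape)                            # torch.Size([2, 16, 6, 64])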