small fixes
commit be6c1b98c4
parent dff976a84a
@@ -251,11 +251,11 @@ public:
             if constexpr (Varlen) {
                 if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
             }
-            // if constexpr (Is_causal || Is_local) {
-            //     int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
-            //     int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
-            //     if (m_block_min >= m_block_max) { continue; }
-            // }
+            if constexpr (Is_causal) {
+                int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
+                int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
+                if (m_block_min >= m_block_max) { continue; }
+            }
             collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
         }
     }
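This hunk switches on the previously commented-out early exit in the dQ store path, but only for Is_causal (the Is_local case is dropped), and takes the upper block bound as cute::ceil_div(seqlen_q, kBlockM), i.e. all query blocks, instead of the windowed get_m_block_max. A rough Python sketch of that bound arithmetic; the tile sizes are illustrative assumptions and the mask is taken as flash-attn's bottom-right-aligned causal convention, not necessarily the kernel's exact get_m_block_min:

import math

kBlockM, kBlockN = 64, 128  # illustrative tile sizes, not the kernel's constants

def m_block_bounds_causal(n_block, seqlen_q, seqlen_k):
    # With a bottom-right-aligned causal mask, query row i attends key cols
    # j <= i + (seqlen_k - seqlen_q), so the first query block overlapping
    # key block n_block follows from the block's starting column. Causality
    # never caps the range from above, hence the upper bound is simply all
    # query blocks, matching the ceil_div(seqlen_q, kBlockM) in the hunk.
    first_row = n_block * kBlockN + seqlen_q - seqlen_k
    m_block_min = max(0, first_row // kBlockM)
    m_block_max = math.ceil(seqlen_q / kBlockM)
    return m_block_min, m_block_max

print(m_block_bounds_causal(n_block=3, seqlen_q=512, seqlen_k=512))  # (6, 8)

# The Varlen guard in the same hunk: a key block that starts past the
# actual sequence length contributes nothing to dQ and is skipped outright.
n_block, seqlen_k = 5, 600
print(n_block * kBlockN >= seqlen_k)  # True, so the kernel would continue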
@@ -284,9 +284,6 @@ public:
         if constexpr (Is_causal || Is_local) {
             int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
             int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
-            auto seqlen_q = collective_mainloop.get_seqlen_q(params.mainloop, bidb);
-            auto seqlen_k = collective_mainloop.get_seqlen_k(params.mainloop, bidb);
-            auto original_m_block_max = cute::ceil_div(seqlen_q, kBlockM);
             if (m_block_min >= m_block_max) {  // We exit early and write 0 to dK and dV
                 collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
                 continue;
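The three locals deleted here (seqlen_q, seqlen_k, original_m_block_max) were unused leftovers; the surviving branch exits early and has the epilogue write zeros for dK/dV whenever no query block overlaps this key block. A toy PyTorch check of why zero is the correct value in that case; the top-left causal mask below is purely for illustration, not the kernel's convention:

import torch

torch.manual_seed(0)
seqlen_q, seqlen_k, d = 4, 16, 8
q = torch.randn(seqlen_q, d)
k = torch.randn(seqlen_k, d, requires_grad=True)
v = torch.randn(seqlen_k, d, requires_grad=True)

# A mask under which the trailing keys are visible to no query at all.
mask = torch.arange(seqlen_k)[None, :] <= torch.arange(seqlen_q)[:, None]

scores = (q @ k.T) / d**0.5
attn = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
(attn @ v).sum().backward()

# Fully masked keys receive exactly zero gradient, which is what the
# store_zero fast path writes without running the mainloop at all.
print(k.grad[seqlen_q:].abs().max().item(), v.grad[seqlen_q:].abs().max().item())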
@@ -1,8 +1,5 @@
 import math
 
-import sys
-sys.path.remove("/home/yingz/llm_inference")
-
 import pytest
 import torch
 import torch.nn.functional as F
@@ -86,11 +83,10 @@ def test_flash_attn_output(
     batch_size = 4
     nheads = 6
     nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
-    # nheads_kv = 1
-    # batch_size = 1
-    # nheads = 1
+    # nheads_kv = 2
+    # batch_size = 9
+    # nheads = 6
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
-    print(f"window_size: {window_size}", flush=True)
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
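Besides dropping the debug print, the context shows window_size drawn at random for the local case. Under flash-attn's sliding-window convention, a pair (left, right) bounds how far each query may look back and ahead, with -1 meaning unbounded on that side. A minimal mask sketch of that semantics for equal seqlens; the bottom-right alignment offset used for unequal lengths is omitted here for brevity:

import torch

def local_mask(seqlen, left, right):
    # Query i may attend key j iff i - left <= j <= i + right;
    # a negative bound disables the limit on that side.
    i = torch.arange(seqlen)[:, None]
    j = torch.arange(seqlen)[None, :]
    lo = (j >= i - left) if left >= 0 else torch.ones(seqlen, seqlen, dtype=torch.bool)
    hi = (j <= i + right) if right >= 0 else torch.ones(seqlen, seqlen, dtype=torch.bool)
    return lo & hi

print(local_mask(6, left=2, right=0).int())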
@@ -255,12 +251,12 @@ def test_flash_attn_varlen_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
+    # batch_size = 1
+    # nheads = 1
+    # nheads_kv = 1
     batch_size = 9
-    nheads = 4
-    nheads_kv = 4
-    # batch_size = 9
-    # nheads = 6
-    # nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
+    nheads = 6
+    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
 
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
 
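With this hunk the varlen test exercises MHA/GQA/MQA head layouts instead of a fixed 4/4 split, matching the non-varlen test above. As a reminder of what the nheads_kv expression encodes, a small reference-style sketch; the shapes and the repeat_interleave expansion follow the usual grouped-KV convention and are shown only for illustration:

import torch

batch, seqlen, d = 2, 16, 32
nheads = 6
for mha_type, nheads_kv in [("mha", 6), ("gqa", 2), ("mqa", 1)]:
    assert nheads % nheads_kv == 0  # each KV head serves a group of query heads
    k = torch.randn(batch, seqlen, nheads_kv, d)
    # Expanding grouped KV heads lets one dense-attention reference path
    # cover all three layouts with identical downstream shapes.
    k_exp = k.repeat_interleave(nheads // nheads_kv, dim=2)
    print(mha_type, tuple(k_exp.shape))  # head dim is 6 in every case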