[kernel] fix sliding window in prefix prefill Triton kernel (#4405)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent 5b8a7c1cb0
commit 32881f3f31
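For orientation before the hunks: this change threads a sliding_window argument from the attention backends down to the prefix-prefill Triton kernel, which then masks out any key more than window positions behind a query, on top of the usual causal constraint. A minimal dense-mask sketch of that rule, for illustration only (the function and its name are not part of this commit), assuming window == 0 means "no sliding window" as in the kernel below:

import torch

def sliding_window_causal_mask(ctx_len: int, query_len: int,
                               window: int) -> torch.Tensor:
    # Boolean [query_len, ctx_len + query_len] mask; True = query may attend to key.
    q_pos = torch.arange(ctx_len, ctx_len + query_len)  # absolute query positions
    k_pos = torch.arange(ctx_len + query_len)            # absolute key positions
    allowed = k_pos[None, :] <= q_pos[:, None]            # causal: no future keys
    if window > 0:                                         # 0 disables the window
        allowed &= (q_pos[:, None] - k_pos[None, :]) < window
    return allowed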
@@ -15,6 +15,7 @@ DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -22,11 +23,13 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
     num_queries_per_kv: int,
     head_size: int,
+    sliding_window: int,
     dtype: torch.dtype,
     device: str,
 ) -> None:
@@ -123,12 +126,32 @@ def test_contexted_kv_attention(
 
     # Warm up the Triton kernel by calling it once before actually measuring
     # generation time
-    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
-                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          sliding_window=sliding_window)
     torch.cuda.synchronize()
     start_time = time.time()
-    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
-                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          sliding_window=sliding_window)
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -156,6 +179,9 @@ def test_contexted_kv_attention(
 
     attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
         subquery_lens, seq_lens)
+    if sliding_window > 0:
+        attn_bias = attn_bias.make_local_attention_from_bottomright(
+            sliding_window)
     output_ref = xops.memory_efficient_attention_forward(
         query,
         key,
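As a usage note for the reference path above: the test builds the xFormers bottom-right causal bias from per-sequence query and total lengths, then converts it to a local-attention bias when a window is set. A standalone sketch with illustrative lengths (the import path is the usual xFormers location and is assumed here, not shown in the diff):

from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

subquery_lens = [4, 7]    # new query tokens per sequence (illustrative)
seq_lens = [20, 35]       # cached context + query tokens per sequence (illustrative)
sliding_window = 16

attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
    subquery_lens, seq_lens)
if sliding_window > 0:
    attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
# attn_bias can then be passed to xops.memory_efficient_attention_forward(...)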
@@ -249,6 +249,7 @@ class FlashAttentionImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
+                    self.sliding_window[0],
                 )
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
@@ -307,6 +307,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
+                    self.sliding_window[0],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -246,6 +246,7 @@ class XFormersImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
+                    self.sliding_window,
                 )
                 assert output[:num_prefill_tokens].shape == out.shape
                 output[:num_prefill_tokens] = out
@@ -172,6 +172,7 @@ class PagedAttention:
         context_lens: torch.Tensor,
         max_subquery_len: int,
         alibi_slopes: Optional[torch.Tensor],
+        sliding_window: Optional[int],
     ) -> torch.Tensor:
         output = torch.empty_like(query)
         context_attention_fwd(
@@ -188,6 +189,7 @@ class PagedAttention:
             context_lens,
             max_subquery_len,
             alibi_slopes,
+            sliding_window,
         )
         return output
 
@@ -50,6 +50,7 @@ if triton.__version__ >= "2.1.0":
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
+        SLIDING_WINDOW: tl.constexpr,
     ):
         cur_batch = tl.program_id(0)
         cur_head = tl.program_id(1)
@@ -62,42 +63,53 @@ if triton.__version__ >= "2.1.0":
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
         cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
 
+        # start position inside of the query
+        # generally, N goes over kv, while M goes over query_len
         block_start_loc = BLOCK_M * start_m
 
         # initialize offsets
+        # [N]; starts at 0
         offs_n = tl.arange(0, BLOCK_N)
+        # [D]; starts at 0
         offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
+        # [M]; starts at current position in query
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        # [M,D]
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
         dim_mask = tl.where(
-            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
+            0).to(tl.int1)  # [D]
 
         q = tl.load(Q + off_q,
                     mask=dim_mask[None, :] &
                     (offs_m[:, None] < cur_batch_query_len),
-                    other=0.0)
+                    other=0.0)  # [M,D]
 
-        # # initialize pointer to m and l
-        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
+        # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
+                       dtype=tl.float32)  # [M,D]
 
+        # compute query against context (no causal mask here)
         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                          ((start_n + offs_n) // block_size) * stride_b_loc_s,
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
-                         other=0)
+                         other=0)  # [N]
+            # [D,N]
             off_k = (bn[None, :] * stride_k_cache_bs +
                      cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
+            # [N,D]
             off_v = (
                 bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
@@ -106,23 +118,39 @@ if triton.__version__ >= "2.1.0":
             k = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
-                        other=0.0)
+                        other=0.0)  # [D,N]
 
-            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
             qk += tl.dot(q, k)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
+            if SLIDING_WINDOW > 0:
+                # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
+                # Q entries in sequence
+                # (start_n + offs_n[None, :]) are the positions of
+                # KV entries in sequence
+                # So the condition makes sure each entry in Q only attends
+                # to KV entries not more than SLIDING_WINDOW away.
+                #
+                # We can't use -inf here, because the
+                # sliding window may lead to the entire row being masked.
+                # This then makes m_ij contain -inf, which causes NaNs in
+                # exp().
+                qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
+                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
+                              -10000)
 
             # -- compute m_ij, p, l_ij
-            m_ij = tl.max(qk, 1)
-            p = tl.exp(qk - m_ij[:, None])
-            l_ij = tl.sum(p, 1)
+            m_ij = tl.max(qk, 1)  # [M]
+            p = tl.exp(qk - m_ij[:, None])  # [M,N]
+            l_ij = tl.sum(p, 1)  # [M]
             # -- update m_i and l_i
-            m_i_new = tl.maximum(m_i, m_ij)
-            alpha = tl.exp(m_i - m_i_new)
-            beta = tl.exp(m_ij - m_i_new)
-            l_i_new = alpha * l_i + beta * l_ij
+            m_i_new = tl.maximum(m_i, m_ij)  # [M]
+            alpha = tl.exp(m_i - m_i_new)  # [M]
+            beta = tl.exp(m_ij - m_i_new)  # [M]
+            l_i_new = alpha * l_i + beta * l_ij  # [M]
 
             # -- update output accumulator --
             # scale p
             p_scale = beta / l_i_new
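The comment about -10000 versus float("-inf") deserves a concrete illustration: when the sliding window masks out every key of a block for some query row, a row filled with -inf makes m_ij equal to -inf, and exp(-inf - (-inf)) is NaN; a large finite negative fill keeps the arithmetic defined while still contributing essentially zero weight after the softmax rescaling. A tiny standalone demonstration of the failure mode (not part of the kernel):

import torch

fully_masked = torch.full((4,), float("-inf"))
m = fully_masked.max()                    # -inf
print(torch.exp(fully_masked - m))        # tensor([nan, nan, nan, nan])

finite_fill = torch.full((4,), -10000.0)
m = finite_fill.max()                     # -10000.0
print(torch.exp(finite_fill - m))         # tensor([1., 1., 1., 1.]) -- no NaNs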
@@ -134,7 +162,7 @@ if triton.__version__ >= "2.1.0":
             v = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
-                        other=0.0)
+                        other=0.0)  # [N,D]
 
             p = p.to(v.dtype)
             acc += tl.dot(p, v)
@@ -149,8 +177,10 @@ if triton.__version__ >= "2.1.0":
         k_ptrs = K + off_k
         v_ptrs = V + off_v
 
+        # block_mask is 0 when we're already past the current query length
         block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
+        # compute query against itself (with causal mask)
         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
@@ -163,8 +193,13 @@ if triton.__version__ >= "2.1.0":
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
             qk *= sm_scale
+            # apply causal mask
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                           float("-inf"))
+            if SLIDING_WINDOW > 0:
+                qk = tl.where(
+                    offs_m[:, None] -
+                    (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)
 
             # -- compute m_ij, p, l_ij
             m_ij = tl.max(qk, 1)
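Taken together, the kernel handles each query block in two passes, which the hunks above mirror: first against the cached context KV with no causal mask (positions compared in absolute terms, so the query position is cur_batch_ctx_len + offs_m), then against the new tokens themselves with a causal mask (both positions relative to the start of the query). A rough non-Triton sketch of the same masking for a single sequence, using illustrative names and the kernel's -10000 fill, assuming window == 0 disables the check:

import torch

def two_pass_masked_scores(q, k_ctx, k_new, scale, window):
    ctx_len, q_len = k_ctx.shape[0], q.shape[0]
    q_pos = ctx_len + torch.arange(q_len)            # absolute positions of new queries

    # Pass 1: query vs. cached context -- no causal mask, only the window check.
    s_ctx = scale * (q @ k_ctx.T)                    # [q_len, ctx_len]
    ctx_pos = torch.arange(ctx_len)
    if window > 0:
        s_ctx = torch.where(q_pos[:, None] - ctx_pos[None, :] < window,
                            s_ctx, torch.tensor(-10000.0))

    # Pass 2: query vs. the new tokens -- causal mask, then the window check.
    s_new = scale * (q @ k_new.T)                    # [q_len, q_len]
    m = torch.arange(q_len)
    s_new = torch.where(m[:, None] >= m[None, :], s_new,
                        torch.tensor(float("-inf")))
    if window > 0:
        s_new = torch.where(m[:, None] - m[None, :] < window,
                            s_new, torch.tensor(-10000.0))

    # Softmax over the concatenation reproduces the attention weights.
    return torch.cat([s_ctx, s_new], dim=-1)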
@@ -636,7 +671,8 @@ if triton.__version__ >= "2.1.0":
                               b_seq_len,
                               b_ctx_len,
                               max_input_len,
-                              alibi_slopes=None):
+                              alibi_slopes=None,
+                              sliding_window=None):
 
         cap = torch.cuda.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
@@ -644,7 +680,7 @@ if triton.__version__ >= "2.1.0":
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
         # round up Lk to a power of 2 - this is required for Triton block size
-        Lk_padded = 2**((Lk - 1).bit_length())
+        Lk_padded = triton.next_power_of_2(Lk)
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
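The switch to triton.next_power_of_2 is behavior-preserving relative to the old bit-length trick for any positive head size; a quick illustrative check (the values are arbitrary):

import triton

for lk in (16, 32, 80, 96, 112, 128, 256):
    assert triton.next_power_of_2(lk) == 2**((lk - 1).bit_length())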
@@ -749,6 +785,7 @@ if triton.__version__ >= "2.1.0":
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
+            SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
             num_warps=num_warps,
             num_stages=1,
         )